diff --git a/.gitattributes b/.gitattributes
index a6344aac8c09253b3b630fb776ae94478aa0275b..cb5a34493dd6a0dbed9c02f2f6bc88e3e1ec3e3a 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
+examples/example2/mask.mp4 filter=lfs diff=lfs merge=lfs -text
+examples/example3/video.mp4 filter=lfs diff=lfs merge=lfs -text
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..7a4a3ea2424c09fbe48d455aed1eaa94d9124835
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
\ No newline at end of file
diff --git a/ORIGINAL_README.md b/ORIGINAL_README.md
new file mode 100644
index 0000000000000000000000000000000000000000..bc7e5c167422ac28bf4702f7121a6b03528f415b
--- /dev/null
+++ b/ORIGINAL_README.md
@@ -0,0 +1,171 @@
+# DiffuEraser: A Diffusion Model for Video Inpainting
+
+Xiaowen Li, Haolan Xue, Peiran Ren, Liefeng Bo
+
+Tongyi Lab, Alibaba Group
+
+[TECHNICAL REPORT](https://arxiv.org/abs/2501.10018)
+
+
+DiffuEraser is a diffusion model for video inpainting. It outperforms the state-of-the-art model ProPainter in both content completeness and temporal consistency while maintaining acceptable efficiency.
+
+---
+
+
+## Update Log
+- *2025.01.20*: Release inference code.
+
+
+## TODO
+- [ ] Release training code.
+- [ ] Release HuggingFace/ModelScope demo.
+- [ ] Release gradio demo.
+
+
+## Results
+More results will be displayed on the project page.
+
+https://github.com/user-attachments/assets/b59d0b88-4186-4531-8698-adf6e62058f8
+
+
+
+
+## Method Overview
+Our network is inspired by [BrushNet](https://github.com/TencentARC/BrushNet) and [AnimateDiff](https://github.com/guoyww/AnimateDiff). The architecture comprises the primary `denoising UNet` and an auxiliary `BrushNet branch`. Features extracted by the BrushNet branch are integrated into the denoising UNet layer by layer after a zero-convolution block, and the denoising UNet performs the denoising process to generate the final output. To enhance temporal consistency, `temporal attention` mechanisms are incorporated after both the self-attention and cross-attention layers. After denoising, the generated images are blended with the input masked images using blurred masks.
+
+![DiffuEraser pipeline](assets/DiffuEraser_pipeline.png)
+
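+As a rough illustration of the final compositing step, the blending can be sketched as follows (a minimal example with OpenCV and NumPy; `generated`, `original` and `mask` are placeholder arrays, not names from this repository):
+
+```python
+import cv2
+import numpy as np
+
+def blend_with_blurred_mask(generated, original, mask, ksize=21):
+    """Blend a generated frame into the original using a Gaussian-blurred binary mask."""
+    mask_blurred = cv2.GaussianBlur(mask.astype(np.float32), (ksize, ksize), 0) / 255.0
+    # keep hard ones inside the hole, feather only the mask boundary
+    soft = 1 - (1 - mask.astype(np.float32) / 255.0) * (1 - mask_blurred)
+    soft = soft[..., None]  # HxWx1, broadcast over the RGB channels
+    return (generated * soft + original * (1 - soft)).astype(np.uint8)
+```
+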
+We incorporate `prior` information to provide initialization and weak conditioning, which helps mitigate noisy artifacts and suppress hallucinations.
+Additionally, to improve temporal consistency during long-sequence inference, we expand the `temporal receptive fields` of both the prior model and DiffuEraser, and further enhance consistency by leveraging the temporal smoothing capabilities of Video Diffusion Models. Please read the paper for details.
+
+
+## Getting Started
+
+#### Installation
+
+1. Clone Repo
+
+ ```bash
+ git clone https://github.com/lixiaowen-xw/DiffuEraser.git
+ ```
+
+2. Create Conda Environment and Install Dependencies
+
+ ```bash
+ # create new anaconda env
+ conda create -n diffueraser python=3.9.19
+ conda activate diffueraser
+ # install python dependencies
+ pip install -r requirements.txt
+ ```
+
+#### Prepare pretrained models
+Weights will be placed under the `./weights` directory.
+1. Download our pretrained models from [Hugging Face](https://huggingface.co/lixiaowen/diffuEraser) or [ModelScope](https://www.modelscope.cn/xingzi/diffuEraser.git) to the `weights` folder.
+2. Download the pretrained weights of the base model and other components:
+   - [stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5). The full folder is over 30 GB; if you want to save storage space, you can download only the necessary files (`feature_extractor`, `model_index.json`, `safety_checker`, `scheduler`, `text_encoder`, and `tokenizer`), about 4 GB.
+ - [PCM_Weights](https://huggingface.co/wangfuyun/PCM_Weights)
+ - [propainter](https://github.com/sczhou/ProPainter/releases/tag/v0.1.0)
+ - [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse)
+
+
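+If you prefer to fetch the Hugging Face checkpoints programmatically, the sketch below is one possible way to do it with `huggingface_hub` (repository ids are taken from the links above; the ProPainter weights come from the GitHub release and still need to be downloaded manually):
+
+```python
+from huggingface_hub import snapshot_download
+
+# DiffuEraser weights (brushnet + unet_main)
+snapshot_download("lixiaowen/diffuEraser", local_dir="weights/diffuEraser")
+# Base model: only the sub-folders DiffuEraser actually needs
+snapshot_download(
+    "stable-diffusion-v1-5/stable-diffusion-v1-5",
+    local_dir="weights/stable-diffusion-v1-5",
+    allow_patterns=["feature_extractor/*", "model_index.json", "safety_checker/*",
+                    "scheduler/*", "text_encoder/*", "tokenizer/*"],
+)
+snapshot_download("wangfuyun/PCM_Weights", local_dir="weights/PCM_Weights",
+                  allow_patterns=["sd15/*"])
+snapshot_download("stabilityai/sd-vae-ft-mse", local_dir="weights/sd-vae-ft-mse")
+```
+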
+The directory structure will be arranged as:
+```
+weights
+ |- diffuEraser
+ |-brushnet
+ |-unet_main
+ |- stable-diffusion-v1-5
+ |-feature_extractor
+ |-...
+ |- PCM_Weights
+ |-sd15
+ |- propainter
+ |-ProPainter.pth
+ |-raft-things.pth
+ |-recurrent_flow_completion.pth
+ |- sd-vae-ft-mse
+ |-diffusion_pytorch_model.bin
+ |-...
+ |- README.md
+```
+
+#### Main Inference
+We provide some examples in the [`examples`](./examples) folder.
+Run the following commands to try it out:
+```shell
+cd DiffuEraser
+python run_diffueraser.py
+```
+The results will be saved in the `results` folder.
+To test your own videos, replace `input_video` and `input_mask` in `run_diffueraser.py`; the underlying API is sketched below. The first inference may take a long time.
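+
+If you would rather call the model from your own script, the inference entry point is the `DiffuEraser` class in `diffueraser/diffueraser.py`. A minimal sketch (paths are placeholders; the `priori` video is the prior produced by ProPainter for the same clip, and note that `forward` deletes the priori file after reading it):
+
+```python
+import torch
+from diffueraser.diffueraser import DiffuEraser
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+model = DiffuEraser(
+    device,
+    base_model_path="weights/stable-diffusion-v1-5",
+    vae_path="weights/sd-vae-ft-mse",
+    diffueraser_path="weights/diffuEraser",
+)
+model.forward(
+    validation_image="path/to/video.mp4",   # input video
+    validation_mask="path/to/mask.mp4",     # mask video, same frame rate as the input
+    priori="path/to/priori.mp4",            # prior video, e.g. the ProPainter result
+    output_path="results/diffueraser_result.mp4",
+    video_length=10,                        # seconds of video to process
+)
+```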
+
+The `frame rate` of `input_video` and `input_mask` must be consistent. We currently only support `mp4 video` as input rather than split frames; you can convert frames to a video using ffmpeg:
+```shell
+ffmpeg -i image%03d.jpg -c:v libx264 -r 25 output.mp4
+```
+Note: if the frame rate of the mask video is not consistent with that of the input video, do not simply convert it; resampling breaks the frame-by-frame alignment and leads to errors. A quick consistency check is sketched below.
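+
+One way to check the frame rates before running, mirroring the check done inside `diffueraser/diffueraser.py` (paths are placeholders):
+
+```python
+import cv2
+
+def get_fps(path):
+    cap = cv2.VideoCapture(path)
+    fps = cap.get(cv2.CAP_PROP_FPS)
+    cap.release()
+    return fps
+
+video_fps = get_fps("path/to/video.mp4")
+mask_fps = get_fps("path/to/mask.mp4")
+assert video_fps == mask_fps, f"Frame rates differ: {video_fps} vs {mask_fps}"
+```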
+
+
+Below are the estimated GPU memory requirements and inference times for different resolutions:
+
+| Resolution | GPU Memory | Inference Time (250 frames, ~10 s, NVIDIA L20) |
+| :--------- | :--------: | :--------------------------------------------: |
+| 1280 x 720 |   33 GB    |                     314 s                      |
+| 960 x 540  |   20 GB    |                     175 s                      |
+| 640 x 360  |   12 GB    |                      92 s                      |
+
+
+## Citation
+
+ If you find our repo useful for your research, please consider citing our paper:
+
+ ```bibtex
+ @misc{li2025diffueraserdiffusionmodelvideo,
+ title={DiffuEraser: A Diffusion Model for Video Inpainting},
+ author={Xiaowen Li and Haolan Xue and Peiran Ren and Liefeng Bo},
+ year={2025},
+ eprint={2501.10018},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV},
+ url={https://arxiv.org/abs/2501.10018},
+}
+ ```
+
+
+## License
+This repository uses [ProPainter](https://github.com/sczhou/ProPainter) as the prior model, so users must comply with [ProPainter's license](https://github.com/sczhou/ProPainter/blob/main/LICENSE) when using this code. Alternatively, you can replace it with another prior model.
+
+This project is licensed under the [Apache License Version 2.0](./LICENSE) except for the third-party components listed below.
+
+
+## Acknowledgement
+
+This code is based on [BrushNet](https://github.com/TencentARC/BrushNet), [ProPainter](https://github.com/sczhou/ProPainter) and [AnimateDiff](https://github.com/guoyww/AnimateDiff). The example videos come from [Pexels](https://www.pexels.com/), [DAVIS](https://davischallenge.org/), [SA-V](https://ai.meta.com/datasets/segment-anything-video) and [DanceTrack](https://dancetrack.github.io/). Thanks for their awesome work.
+
+
diff --git a/assets/DiffuEraser_pipeline.png b/assets/DiffuEraser_pipeline.png
new file mode 100644
index 0000000000000000000000000000000000000000..ab87700db6ffc786aee5f205cb18c1eed158bb4f
Binary files /dev/null and b/assets/DiffuEraser_pipeline.png differ
diff --git a/diffueraser/diffueraser.py b/diffueraser/diffueraser.py
new file mode 100644
index 0000000000000000000000000000000000000000..108e72a0f360e375aaa1e0fe74ca71d3a90260cc
--- /dev/null
+++ b/diffueraser/diffueraser.py
@@ -0,0 +1,432 @@
+import gc
+import copy
+import cv2
+import os
+import numpy as np
+import torch
+import torchvision
+from einops import repeat
+from PIL import Image, ImageFilter
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ UniPCMultistepScheduler,
+ LCMScheduler,
+)
+from diffusers.schedulers import TCDScheduler
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.utils.torch_utils import randn_tensor
+from transformers import AutoTokenizer, PretrainedConfig
+
+from libs.unet_motion_model import MotionAdapter, UNetMotionModel
+from libs.brushnet_CA import BrushNetModel
+from libs.unet_2d_condition import UNet2DConditionModel
+from diffueraser.pipeline_diffueraser import StableDiffusionDiffuEraserPipeline
+
+
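+# PCM LoRA checkpoints: display name -> [weight filename template (formatted with the
+# model mode, e.g. "sd15"), number of inference steps, guidance scale].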
+checkpoints = {
+ "2-Step": ["pcm_{}_smallcfg_2step_converted.safetensors", 2, 0.0],
+ "4-Step": ["pcm_{}_smallcfg_4step_converted.safetensors", 4, 0.0],
+ "8-Step": ["pcm_{}_smallcfg_8step_converted.safetensors", 8, 0.0],
+ "16-Step": ["pcm_{}_smallcfg_16step_converted.safetensors", 16, 0.0],
+ "Normal CFG 4-Step": ["pcm_{}_normalcfg_4step_converted.safetensors", 4, 7.5],
+ "Normal CFG 8-Step": ["pcm_{}_normalcfg_8step_converted.safetensors", 8, 7.5],
+ "Normal CFG 16-Step": ["pcm_{}_normalcfg_16step_converted.safetensors", 16, 7.5],
+ "LCM-Like LoRA": [
+ "pcm_{}_lcmlike_lora_converted.safetensors",
+ 4,
+ 0.0,
+ ],
+}
+
+def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
+ text_encoder_config = PretrainedConfig.from_pretrained(
+ pretrained_model_name_or_path,
+ subfolder="text_encoder",
+ revision=revision,
+ )
+ model_class = text_encoder_config.architectures[0]
+
+ if model_class == "CLIPTextModel":
+ from transformers import CLIPTextModel
+
+ return CLIPTextModel
+ elif model_class == "RobertaSeriesModelWithTransformation":
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
+
+ return RobertaSeriesModelWithTransformation
+ else:
+ raise ValueError(f"{model_class} is not supported.")
+
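+# Resize frames so width and height are multiples of 8, as required by the VAE downsampling.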
+def resize_frames(frames, size=None):
+ if size is not None:
+ out_size = size
+ process_size = (out_size[0]-out_size[0]%8, out_size[1]-out_size[1]%8)
+ frames = [f.resize(process_size) for f in frames]
+ else:
+ out_size = frames[0].size
+ process_size = (out_size[0]-out_size[0]%8, out_size[1]-out_size[1]%8)
+ if not out_size == process_size:
+ frames = [f.resize(process_size) for f in frames]
+
+ return frames
+
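+# Read the mask video, binarize each mask, erode it once and dilate it `mask_dilation_iter`
+# times, and build the corresponding masked input frames. Raises if the mask fps differs
+# from the input video fps.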
+def read_mask(validation_mask, fps, n_total_frames, img_size, mask_dilation_iter, frames):
+ cap = cv2.VideoCapture(validation_mask)
+ if not cap.isOpened():
+ print("Error: Could not open mask video.")
+ exit()
+ mask_fps = cap.get(cv2.CAP_PROP_FPS)
+ if mask_fps != fps:
+ cap.release()
+ raise ValueError("The frame rate of all input videos needs to be consistent.")
+
+ masks = []
+ masked_images = []
+ idx = 0
+ while True:
+ ret, frame = cap.read()
+ if not ret:
+ break
+ if(idx >= n_total_frames):
+ break
+ mask = Image.fromarray(frame[...,::-1]).convert('L')
+ if mask.size != img_size:
+ mask = mask.resize(img_size, Image.NEAREST)
+ mask = np.asarray(mask)
+ m = np.array(mask > 0).astype(np.uint8)
+ m = cv2.erode(m,
+ cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3)),
+ iterations=1)
+ m = cv2.dilate(m,
+ cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3)),
+ iterations=mask_dilation_iter)
+
+ mask = Image.fromarray(m * 255)
+ masks.append(mask)
+
+ masked_image = np.array(frames[idx])*(1-(np.array(mask)[:,:,np.newaxis].astype(np.float32)/255))
+ masked_image = Image.fromarray(masked_image.astype(np.uint8))
+ masked_images.append(masked_image)
+
+ idx += 1
+ cap.release()
+
+ return masks, masked_images
+
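+# Read the prior video (e.g. the ProPainter output), resize its frames to the working size,
+# and delete the temporary priori file afterwards.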
+def read_priori(priori, fps, n_total_frames, img_size):
+ cap = cv2.VideoCapture(priori)
+ if not cap.isOpened():
+ print("Error: Could not open video.")
+ exit()
+ priori_fps = cap.get(cv2.CAP_PROP_FPS)
+ if priori_fps != fps:
+ cap.release()
+ raise ValueError("The frame rate of all input videos needs to be consistent.")
+
+ prioris=[]
+ idx = 0
+ while True:
+ ret, frame = cap.read()
+ if not ret:
+ break
+ if(idx >= n_total_frames):
+ break
+ img = Image.fromarray(frame[...,::-1])
+ if img.size != img_size:
+ img = img.resize(img_size)
+ prioris.append(img)
+ idx += 1
+ cap.release()
+
+ os.remove(priori) # remove priori
+
+ return prioris
+
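+# Read up to `video_length` seconds of the input video, validate its resolution
+# (the longer side must be between 256 and 4096), and resize so both sides are
+# multiples of 8 and the longer side does not exceed `max_img_size`.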
+def read_video(validation_image, video_length, nframes, max_img_size):
+ vframes, aframes, info = torchvision.io.read_video(filename=validation_image, pts_unit='sec', end_pts=video_length) # RGB
+ fps = info['video_fps']
+ n_total_frames = int(video_length * fps)
+ n_clip = int(np.ceil(n_total_frames/nframes))
+
+ frames = list(vframes.numpy())[:n_total_frames]
+ frames = [Image.fromarray(f) for f in frames]
+ max_size = max(frames[0].size)
+ if(max_size<256):
+ raise ValueError("The resolution of the uploaded video must be larger than 256x256.")
+ if(max_size>4096):
+ raise ValueError("The resolution of the uploaded video must be smaller than 4096x4096.")
+ if max_size>max_img_size:
+ ratio = max_size/max_img_size
+ ratio_size = (int(frames[0].size[0]/ratio),int(frames[0].size[1]/ratio))
+ img_size = (ratio_size[0]-ratio_size[0]%8, ratio_size[1]-ratio_size[1]%8)
+ resize_flag=True
+ elif (frames[0].size[0]%8==0) and (frames[0].size[1]%8==0):
+ img_size = frames[0].size
+ resize_flag=False
+ else:
+ ratio_size = frames[0].size
+ img_size = (ratio_size[0]-ratio_size[0]%8, ratio_size[1]-ratio_size[1]%8)
+ resize_flag=True
+ if resize_flag:
+ frames = resize_frames(frames, img_size)
+ img_size = frames[0].size
+
+ return frames, fps, img_size, n_clip, n_total_frames
+
+
+class DiffuEraser:
+ def __init__(
+ self, device, base_model_path, vae_path, diffueraser_path, revision=None,
+ ckpt="Normal CFG 4-Step", mode="sd15", loaded=None):
+ self.device = device
+
+ ## load model
+ self.vae = AutoencoderKL.from_pretrained(vae_path)
+ self.noise_scheduler = DDPMScheduler.from_pretrained(base_model_path,
+ subfolder="scheduler",
+ prediction_type="v_prediction",
+ timestep_spacing="trailing",
+ rescale_betas_zero_snr=True
+ )
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ base_model_path,
+ subfolder="tokenizer",
+ use_fast=False,
+ )
+ text_encoder_cls = import_model_class_from_model_name_or_path(base_model_path,revision)
+ self.text_encoder = text_encoder_cls.from_pretrained(
+ base_model_path, subfolder="text_encoder"
+ )
+ self.brushnet = BrushNetModel.from_pretrained(diffueraser_path, subfolder="brushnet")
+ self.unet_main = UNetMotionModel.from_pretrained(
+ diffueraser_path, subfolder="unet_main",
+ )
+
+ ## set pipeline
+ self.pipeline = StableDiffusionDiffuEraserPipeline.from_pretrained(
+ base_model_path,
+ vae=self.vae,
+ text_encoder=self.text_encoder,
+ tokenizer=self.tokenizer,
+ unet=self.unet_main,
+ brushnet=self.brushnet
+ ).to(self.device, torch.float16)
+ self.pipeline.scheduler = UniPCMultistepScheduler.from_config(self.pipeline.scheduler.config)
+ self.pipeline.set_progress_bar_config(disable=True)
+
+ self.noise_scheduler = UniPCMultistepScheduler.from_config(self.pipeline.scheduler.config)
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
+
+ ## use PCM
+ self.ckpt = ckpt
+ PCM_ckpts = checkpoints[ckpt][0].format(mode)
+ self.guidance_scale = checkpoints[ckpt][2]
+ if loaded != (ckpt + mode):
+ self.pipeline.load_lora_weights(
+ "weights/PCM_Weights", weight_name=PCM_ckpts, subfolder=mode
+ )
+ loaded = ckpt + mode
+
+ if ckpt == "LCM-Like LoRA":
+ self.pipeline.scheduler = LCMScheduler()
+ else:
+ self.pipeline.scheduler = TCDScheduler(
+ num_train_timesteps=1000,
+ beta_start=0.00085,
+ beta_end=0.012,
+ beta_schedule="scaled_linear",
+ timestep_spacing="trailing",
+ )
+ self.num_inference_steps = checkpoints[ckpt][1]
+ self.guidance_scale = 0
+
+ def forward(self, validation_image, validation_mask, priori, output_path,
+ max_img_size = 1280, video_length=2, mask_dilation_iter=4,
+ nframes=22, seed=None, revision = None, guidance_scale=None, blended=True):
+ validation_prompt = "" #
+ guidance_scale_final = self.guidance_scale if guidance_scale==None else guidance_scale
+
+ if (max_img_size<256 or max_img_size>1920):
+ raise ValueError("The max_img_size must be larger than 256, smaller than 1920.")
+
+ ################ read input video ################
+ frames, fps, img_size, n_clip, n_total_frames = read_video(validation_image, video_length, nframes, max_img_size)
+ video_len = len(frames)
+
+ ################ read mask ################
+ validation_masks_input, validation_images_input = read_mask(validation_mask, fps, video_len, img_size, mask_dilation_iter, frames)
+
+ ################ read priori ################
+ prioris = read_priori(priori, fps, n_total_frames, img_size)
+
+ ## recheck
+ n_total_frames = min(min(len(frames), len(validation_masks_input)), len(prioris))
+ if(n_total_frames<22):
+ raise ValueError("The effective video duration is too short. Please make sure that the number of frames of video, mask, and priori is at least greater than 22 frames.")
+ validation_masks_input = validation_masks_input[:n_total_frames]
+ validation_images_input = validation_images_input[:n_total_frames]
+ frames = frames[:n_total_frames]
+ prioris = prioris[:n_total_frames]
+
+ prioris = resize_frames(prioris)
+ validation_masks_input = resize_frames(validation_masks_input)
+ validation_images_input = resize_frames(validation_images_input)
+ resized_frames = resize_frames(frames)
+
+ ##############################################
+ # DiffuEraser inference
+ ##############################################
+ print("DiffuEraser inference...")
+ if seed is None:
+ generator = None
+ else:
+ generator = torch.Generator(device=self.device).manual_seed(seed)
+
+ ## random noise
+ real_video_length = len(validation_images_input)
+ tar_width, tar_height = validation_images_input[0].size
+ shape = (
+ nframes,
+ 4,
+ tar_height//8,
+ tar_width//8
+ )
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet_main is not None:
+ prompt_embeds_dtype = self.unet_main.dtype
+ else:
+ prompt_embeds_dtype = torch.float16
+ noise_pre = randn_tensor(shape, device=torch.device(self.device), dtype=prompt_embeds_dtype, generator=generator)
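+        # repeat the per-clip noise along time so every clip starts from the same initial noise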
+ noise = repeat(noise_pre, "t c h w->(repeat t) c h w", repeat=n_clip)[:real_video_length,...]
+
+ ################ prepare priori ################
+ images_preprocessed = []
+ for image in prioris:
+ image = self.image_processor.preprocess(image, height=tar_height, width=tar_width).to(dtype=torch.float32)
+ image = image.to(device=torch.device(self.device), dtype=torch.float16)
+ images_preprocessed.append(image)
+ pixel_values = torch.cat(images_preprocessed)
+
+ with torch.no_grad():
+ pixel_values = pixel_values.to(dtype=torch.float16)
+ latents = []
+ num=4
+ for i in range(0, pixel_values.shape[0], num):
+ latents.append(self.vae.encode(pixel_values[i : i + num]).latent_dist.sample())
+ latents = torch.cat(latents, dim=0)
+ latents = latents * self.vae.config.scaling_factor #[(b f), c1, h, w], c1=4
+ torch.cuda.empty_cache()
+ timesteps = torch.tensor([0], device=self.device)
+ timesteps = timesteps.long()
+
+ validation_masks_input_ori = copy.deepcopy(validation_masks_input)
+ resized_frames_ori = copy.deepcopy(resized_frames)
+ ################ Pre-inference ################
+ if n_total_frames > nframes*2: ## do pre-inference only when number of input frames is larger than nframes*2
+ ## sample
+ step = n_total_frames / nframes
+ sample_index = [int(i * step) for i in range(nframes)]
+ sample_index = sample_index[:22]
+ validation_masks_input_pre = [validation_masks_input[i] for i in sample_index]
+ validation_images_input_pre = [validation_images_input[i] for i in sample_index]
+ latents_pre = torch.stack([latents[i] for i in sample_index])
+
+ ## add proiri
+ noisy_latents_pre = self.noise_scheduler.add_noise(latents_pre, noise_pre, timesteps)
+ latents_pre = noisy_latents_pre
+
+ with torch.no_grad():
+ latents_pre_out = self.pipeline(
+ num_frames=nframes,
+ prompt=validation_prompt,
+ images=validation_images_input_pre,
+ masks=validation_masks_input_pre,
+ num_inference_steps=self.num_inference_steps,
+ generator=generator,
+ guidance_scale=guidance_scale_final,
+ latents=latents_pre,
+ ).latents
+ torch.cuda.empty_cache()
+
+ def decode_latents(latents, weight_dtype):
+ latents = 1 / self.vae.config.scaling_factor * latents
+ video = []
+ for t in range(latents.shape[0]):
+ video.append(self.vae.decode(latents[t:t+1, ...].to(weight_dtype)).sample)
+ video = torch.concat(video, dim=0)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ video = video.float()
+ return video
+ with torch.no_grad():
+ video_tensor_temp = decode_latents(latents_pre_out, weight_dtype=torch.float16)
+ images_pre_out = self.image_processor.postprocess(video_tensor_temp, output_type="pil")
+ torch.cuda.empty_cache()
+
+ ## replace input frames with updated frames
+ black_image = Image.new('L', validation_masks_input[0].size, color=0)
+ for i,index in enumerate(sample_index):
+ latents[index] = latents_pre_out[i]
+ validation_masks_input[index] = black_image
+ validation_images_input[index] = images_pre_out[i]
+ resized_frames[index] = images_pre_out[i]
+ else:
+ latents_pre_out=None
+ sample_index=None
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ ################ Frame-by-frame inference ################
+ ## add priori
+ noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
+ latents = noisy_latents
+ with torch.no_grad():
+ images = self.pipeline(
+ num_frames=nframes,
+ prompt=validation_prompt,
+ images=validation_images_input,
+ masks=validation_masks_input,
+ num_inference_steps=self.num_inference_steps,
+ generator=generator,
+ guidance_scale=guidance_scale_final,
+ latents=latents,
+ ).frames
+ images = images[:real_video_length]
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ ################ Compose ################
+ binary_masks = validation_masks_input_ori
+ mask_blurreds = []
+ if blended:
+ # blur, you can adjust the parameters for better performance
+ for i in range(len(binary_masks)):
+ mask_blurred = cv2.GaussianBlur(np.array(binary_masks[i]), (21, 21), 0)/255.
+ binary_mask = 1-(1-np.array(binary_masks[i])/255.) * (1-mask_blurred)
+ mask_blurreds.append(Image.fromarray((binary_mask*255).astype(np.uint8)))
+ binary_masks = mask_blurreds
+
+ comp_frames = []
+ for i in range(len(images)):
+ mask = np.expand_dims(np.array(binary_masks[i]),2).repeat(3, axis=2).astype(np.float32)/255.
+ img = (np.array(images[i]).astype(np.uint8) * mask \
+ + np.array(resized_frames_ori[i]).astype(np.uint8) * (1 - mask)).astype(np.uint8)
+ comp_frames.append(Image.fromarray(img))
+
+ default_fps = fps
+ writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"),
+ default_fps, comp_frames[0].size)
+ for f in range(real_video_length):
+ img = np.array(comp_frames[f]).astype(np.uint8)
+ writer.write(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
+ writer.release()
+ ################################
+
+ return output_path
+
+
+
+
diff --git a/diffueraser/pipeline_diffueraser.py b/diffueraser/pipeline_diffueraser.py
new file mode 100644
index 0000000000000000000000000000000000000000..db4e230e8a7c627651cd8319a31c71823f20f390
--- /dev/null
+++ b/diffueraser/pipeline_diffueraser.py
@@ -0,0 +1,1349 @@
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+import numpy as np
+import PIL.Image
+from einops import rearrange, repeat
+from dataclasses import dataclass
+import copy
+import torch
+import torch.nn.functional as F
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
+
+from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
+from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL, ImageProjection
+from diffusers.models.lora import adjust_lora_scale_text_encoder
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ USE_PEFT_BACKEND,
+ deprecate,
+ logging,
+ replace_example_docstring,
+ scale_lora_layers,
+ unscale_lora_layers,
+ BaseOutput
+)
+from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin
+from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
+from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+from diffusers import (
+ AutoencoderKL,
+ DDPMScheduler,
+ UniPCMultistepScheduler,
+)
+
+from libs.unet_2d_condition import UNet2DConditionModel
+from libs.brushnet_CA import BrushNetModel
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
+def retrieve_timesteps(
+ scheduler,
+ num_inference_steps: Optional[int] = None,
+ device: Optional[Union[str, torch.device]] = None,
+ timesteps: Optional[List[int]] = None,
+ **kwargs,
+):
+ """
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
+
+ Args:
+ scheduler (`SchedulerMixin`):
+ The scheduler to get timesteps from.
+ num_inference_steps (`int`):
+ The number of diffusion steps used when generating samples with a pre-trained model. If used,
+ `timesteps` must be `None`.
+ device (`str` or `torch.device`, *optional*):
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
+ timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
+ must be `None`.
+
+ Returns:
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
+ second element is the number of inference steps.
+ """
+ if timesteps is not None:
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+ if not accepts_timesteps:
+ raise ValueError(
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+ f" timestep schedules. Please check whether you are using the correct scheduler."
+ )
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ num_inference_steps = len(timesteps)
+ else:
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
+ timesteps = scheduler.timesteps
+ return timesteps, num_inference_steps
+
+def get_frames_context_swap(total_frames=192, overlap=4, num_frames_per_clip=24):
+    # Build two lists of overlapping clip index windows over the frame sequence:
+    # `context_list` starts at frame 0, while `context_list_swap` is shifted by half a
+    # clip so that alternating between them smooths the clip boundaries.
+    # NOTE: a single sampling interval covering all frames is assumed here.
+    context_list = []
+    context_list_swap = []
+    sample_interval = list(range(total_frames))
+    n = len(sample_interval)
+    i = 1
+    if n > num_frames_per_clip:
+        ## [0,num_frames_per_clip-1], [num_frames_per_clip, 2*num_frames_per_clip-1]....
+        for k in range(0, n-num_frames_per_clip, num_frames_per_clip-overlap):
+            context_list.append(sample_interval[k:k+num_frames_per_clip])
+        if k+num_frames_per_clip < n and i==1:
+            context_list.append(sample_interval[n-num_frames_per_clip:n])
+        context_list_swap.append(sample_interval[0:num_frames_per_clip])
+        for k in range(num_frames_per_clip//2, n-num_frames_per_clip, num_frames_per_clip-overlap):
+            context_list_swap.append(sample_interval[k:k+num_frames_per_clip])
+        if k+num_frames_per_clip < n and i==1:
+            context_list_swap.append(sample_interval[n-num_frames_per_clip:n])
+    if n == num_frames_per_clip:
+        context_list.append(sample_interval[n-num_frames_per_clip:n])
+        context_list_swap.append(sample_interval[n-num_frames_per_clip:n])
+    return context_list, context_list_swap
+
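+# Pipeline output container: the decoded video frames and the final denoised latents
+# (DiffuEraser reuses the pre-inference latents as initialization for the full pass).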
+@dataclass
+class DiffuEraserPipelineOutput(BaseOutput):
+ frames: Union[torch.Tensor, np.ndarray]
+ latents: Union[torch.Tensor, np.ndarray]
+
+class StableDiffusionDiffuEraserPipeline(
+ DiffusionPipeline,
+ StableDiffusionMixin,
+ TextualInversionLoaderMixin,
+ LoraLoaderMixin,
+ IPAdapterMixin,
+ FromSingleFileMixin,
+):
+ r"""
+ Pipeline for video inpainting using Video Diffusion Model with BrushNet guidance.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
+
+ The pipeline also inherits the following loading methods:
+ - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
+ - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
+ - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files
+ - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
+ text_encoder ([`~transformers.CLIPTextModel`]):
+ Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
+ tokenizer ([`~transformers.CLIPTokenizer`]):
+ A `CLIPTokenizer` to tokenize text.
+ unet ([`UNet2DConditionModel`]):
+ A `UNet2DConditionModel` to denoise the encoded image latents.
+ brushnet ([`BrushNetModel`]`):
+ Provides additional conditioning to the `unet` during the denoising process.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details
+ about a model's potential harms.
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
+ A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`.
+ """
+
+ model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae"
+ _optional_components = ["safety_checker", "feature_extractor", "image_encoder"]
+ _exclude_from_cpu_offload = ["safety_checker"]
+ _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ brushnet: BrushNetModel,
+ scheduler: KarrasDiffusionSchedulers,
+ safety_checker: StableDiffusionSafetyChecker,
+ feature_extractor: CLIPImageProcessor,
+ image_encoder: CLIPVisionModelWithProjection = None,
+ requires_safety_checker: bool = True,
+ ):
+ super().__init__()
+
+ if safety_checker is None and requires_safety_checker:
+ logger.warning(
+ f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
+ " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
+ " results in services or applications open to the public. Both the diffusers team and Hugging Face"
+ " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
+ " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
+ " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
+ )
+
+ if safety_checker is not None and feature_extractor is None:
+ raise ValueError(
+ "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety"
+ " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
+ )
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ brushnet=brushnet,
+ scheduler=scheduler,
+ safety_checker=safety_checker,
+ feature_extractor=feature_extractor,
+ image_encoder=image_encoder,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+ self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
+ self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ **kwargs,
+ ):
+ deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple."
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False)
+
+ prompt_embeds_tuple = self.encode_prompt(
+ prompt=prompt,
+ device=device,
+ num_images_per_prompt=num_images_per_prompt,
+ do_classifier_free_guidance=do_classifier_free_guidance,
+ negative_prompt=negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=lora_scale,
+ **kwargs,
+ )
+
+ # concatenate for backwards comp
+ prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt
+ def encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ lora_scale: Optional[float] = None,
+ clip_skip: Optional[int] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ lora_scale (`float`, *optional*):
+ A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ """
+ # set lora scale so that monkey patched LoRA
+ # function of text encoder can correctly access it
+ if lora_scale is not None and isinstance(self, LoraLoaderMixin):
+ self._lora_scale = lora_scale
+
+ # dynamically adjust the LoRA scale
+ if not USE_PEFT_BACKEND:
+ adjust_lora_scale_text_encoder(self.text_encoder, lora_scale)
+ else:
+ scale_lora_layers(self.text_encoder, lora_scale)
+
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ if clip_skip is None:
+ prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask)
+ prompt_embeds = prompt_embeds[0]
+ else:
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True
+ )
+ # Access the `hidden_states` first, that contains a tuple of
+ # all the hidden states from the encoder layers. Then index into
+ # the tuple to access the hidden states from the desired layer.
+ prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)]
+ # We also need to apply the final LayerNorm here to not mess with the
+ # representations. The `last_hidden_states` that we typically use for
+ # obtaining the final prompt representations passes through the LayerNorm
+ # layer.
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds)
+
+ if self.text_encoder is not None:
+ prompt_embeds_dtype = self.text_encoder.dtype
+ elif self.unet is not None:
+ prompt_embeds_dtype = self.unet.dtype
+ else:
+ prompt_embeds_dtype = prompt_embeds.dtype
+
+ prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif prompt is not None and type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND:
+ # Retrieve the original scale by scaling back the LoRA layers
+ unscale_lora_layers(self.text_encoder, lora_scale)
+
+ return prompt_embeds, negative_prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image
+ def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None):
+ dtype = next(self.image_encoder.parameters()).dtype
+
+ if not isinstance(image, torch.Tensor):
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+ image = image.to(device=device, dtype=dtype)
+ if output_hidden_states:
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_enc_hidden_states = self.image_encoder(
+ torch.zeros_like(image), output_hidden_states=True
+ ).hidden_states[-2]
+ uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave(
+ num_images_per_prompt, dim=0
+ )
+ return image_enc_hidden_states, uncond_image_enc_hidden_states
+ else:
+ image_embeds = self.image_encoder(image).image_embeds
+ image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+ uncond_image_embeds = torch.zeros_like(image_embeds)
+
+ return image_embeds, uncond_image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+ def prepare_ip_adapter_image_embeds(
+ self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
+ ):
+ if ip_adapter_image_embeds is None:
+ if not isinstance(ip_adapter_image, list):
+ ip_adapter_image = [ip_adapter_image]
+
+ if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+ raise ValueError(
+ f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+ )
+
+ image_embeds = []
+ for single_ip_adapter_image, image_proj_layer in zip(
+ ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+ ):
+ output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+ single_image_embeds, single_negative_image_embeds = self.encode_image(
+ single_ip_adapter_image, device, 1, output_hidden_state
+ )
+ single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+ single_negative_image_embeds = torch.stack(
+ [single_negative_image_embeds] * num_images_per_prompt, dim=0
+ )
+
+ if do_classifier_free_guidance:
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ single_image_embeds = single_image_embeds.to(device)
+
+ image_embeds.append(single_image_embeds)
+ else:
+ repeat_dims = [1]
+ image_embeds = []
+ for single_image_embeds in ip_adapter_image_embeds:
+ if do_classifier_free_guidance:
+ single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
+ single_image_embeds = single_image_embeds.repeat(
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+ )
+ single_negative_image_embeds = single_negative_image_embeds.repeat(
+ num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
+ )
+ single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+ else:
+ single_image_embeds = single_image_embeds.repeat(
+ num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
+ )
+ image_embeds.append(single_image_embeds)
+
+ return image_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
+ def run_safety_checker(self, image, device, dtype):
+ if self.safety_checker is None:
+ has_nsfw_concept = None
+ else:
+ if torch.is_tensor(image):
+ feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
+ else:
+ feature_extractor_input = self.image_processor.numpy_to_pil(image)
+ safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
+ image, has_nsfw_concept = self.safety_checker(
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
+ )
+ return image, has_nsfw_concept
+
+ # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents
+ def decode_latents(self, latents, weight_dtype):
+ latents = 1 / self.vae.config.scaling_factor * latents
+ video = []
+ for t in range(latents.shape[0]):
+ video.append(self.vae.decode(latents[t:t+1, ...].to(weight_dtype)).sample)
+ video = torch.concat(video, dim=0)
+
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ video = video.float()
+ return video
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ images,
+ masks,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ip_adapter_image=None,
+ ip_adapter_image_embeds=None,
+ brushnet_conditioning_scale=1.0,
+ control_guidance_start=0.0,
+ control_guidance_end=1.0,
+ callback_on_step_end_tensor_inputs=None,
+ ):
+ if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if callback_on_step_end_tensor_inputs is not None and not all(
+ k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+ ):
+ raise ValueError(
+ f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ # Check `image`
+ is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance(
+ self.brushnet, torch._dynamo.eval_frame.OptimizedModule
+ )
+ if (
+ isinstance(self.brushnet, BrushNetModel)
+ or is_compiled
+ and isinstance(self.brushnet._orig_mod, BrushNetModel)
+ ):
+ self.check_image(images, masks, prompt, prompt_embeds)
+ else:
+ assert False
+
+ # Check `brushnet_conditioning_scale`
+ if (
+ isinstance(self.brushnet, BrushNetModel)
+ or is_compiled
+ and isinstance(self.brushnet._orig_mod, BrushNetModel)
+ ):
+ if not isinstance(brushnet_conditioning_scale, float):
+ raise TypeError("For single brushnet: `brushnet_conditioning_scale` must be type `float`.")
+ else:
+ assert False
+
+ if not isinstance(control_guidance_start, (tuple, list)):
+ control_guidance_start = [control_guidance_start]
+
+ if not isinstance(control_guidance_end, (tuple, list)):
+ control_guidance_end = [control_guidance_end]
+
+ if len(control_guidance_start) != len(control_guidance_end):
+ raise ValueError(
+ f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list."
+ )
+
+ for start, end in zip(control_guidance_start, control_guidance_end):
+ if start >= end:
+ raise ValueError(
+ f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}."
+ )
+ if start < 0.0:
+ raise ValueError(f"control guidance start: {start} can't be smaller than 0.")
+ if end > 1.0:
+ raise ValueError(f"control guidance end: {end} can't be larger than 1.0.")
+
+ if ip_adapter_image is not None and ip_adapter_image_embeds is not None:
+ raise ValueError(
+ "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined."
+ )
+
+ if ip_adapter_image_embeds is not None:
+ if not isinstance(ip_adapter_image_embeds, list):
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}"
+ )
+ elif ip_adapter_image_embeds[0].ndim not in [3, 4]:
+ raise ValueError(
+ f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D"
+ )
+
+ def check_image(self, images, masks, prompt, prompt_embeds):
+ for image in images:
+ image_is_pil = isinstance(image, PIL.Image.Image)
+ image_is_tensor = isinstance(image, torch.Tensor)
+ image_is_np = isinstance(image, np.ndarray)
+ image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image)
+ image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor)
+ image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray)
+
+ if (
+ not image_is_pil
+ and not image_is_tensor
+ and not image_is_np
+ and not image_is_pil_list
+ and not image_is_tensor_list
+ and not image_is_np_list
+ ):
+ raise TypeError(
+ f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}"
+ )
+ for mask in masks:
+ mask_is_pil = isinstance(mask, PIL.Image.Image)
+ mask_is_tensor = isinstance(mask, torch.Tensor)
+ mask_is_np = isinstance(mask, np.ndarray)
+ mask_is_pil_list = isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image)
+ mask_is_tensor_list = isinstance(mask, list) and isinstance(mask[0], torch.Tensor)
+ mask_is_np_list = isinstance(mask, list) and isinstance(mask[0], np.ndarray)
+
+ if (
+ not mask_is_pil
+ and not mask_is_tensor
+ and not mask_is_np
+ and not mask_is_pil_list
+ and not mask_is_tensor_list
+ and not mask_is_np_list
+ ):
+ raise TypeError(
+ f"mask must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(mask)}"
+ )
+
+ if image_is_pil:
+ image_batch_size = 1
+ else:
+ image_batch_size = len(image)
+
+ if prompt is not None and isinstance(prompt, str):
+ prompt_batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ prompt_batch_size = len(prompt)
+ elif prompt_embeds is not None:
+ prompt_batch_size = prompt_embeds.shape[0]
+
+ if image_batch_size != 1 and image_batch_size != prompt_batch_size:
+ raise ValueError(
+ f"If image batch size is not 1, image batch size must be same as prompt batch size. image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}"
+ )
+
+ def prepare_image(
+ self,
+ images,
+ width,
+ height,
+ batch_size,
+ num_images_per_prompt,
+ device,
+ dtype,
+ do_classifier_free_guidance=False,
+ guess_mode=False,
+ ):
+ images_new = []
+ for image in images:
+ image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
+ image_batch_size = image.shape[0]
+
+ if image_batch_size == 1:
+ repeat_by = batch_size
+ else:
+ # image batch size is the same as prompt batch size
+ repeat_by = num_images_per_prompt
+
+ image = image.repeat_interleave(repeat_by, dim=0)
+
+ image = image.to(device=device, dtype=dtype)
+
+ # if do_classifier_free_guidance and not guess_mode:
+ # image = torch.cat([image] * 2)
+ images_new.append(image)
+
+ return images_new
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+ def prepare_latents(self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None):
+ # shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor)
+        # latent shape: (b, c, t, h, w)
+ shape = (
+ batch_size,
+ num_channels_latents,
+ num_frames,
+ height // self.vae_scale_factor,
+ width // self.vae_scale_factor
+ )
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ # noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ noise = rearrange(randn_tensor(shape, generator=generator, device=device, dtype=dtype), "b c t h w -> (b t) c h w")
+ else:
+ noise = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = noise * self.scheduler.init_noise_sigma
+ return latents, noise
+
+ @staticmethod
+ def temp_blend(a, b, overlap):
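+        # Linearly cross-fades the first `overlap` frames of `a` into `b` and keeps `b` for the remaining frames.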
+ factor = torch.arange(overlap).to(b.device).view(overlap, 1, 1, 1) / (overlap - 1)
+ a[:overlap, ...] = (1 - factor) * a[:overlap, ...] + factor * b[:overlap, ...]
+ a[overlap:, ...] = b[overlap:, ...]
+ return a
+
+ # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding
+ def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
+ """
+ See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298
+
+        Args:
+            w (`torch.Tensor`):
+                Guidance scale values used to generate the embedding vectors.
+            embedding_dim (`int`, *optional*, defaults to 512):
+                Dimension of the embeddings to generate.
+            dtype (`torch.dtype`, *optional*, defaults to `torch.float32`):
+                Data type of the generated embeddings.
+
+        Returns:
+            `torch.FloatTensor`: Embedding vectors with shape `(len(w), embedding_dim)`.
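+
+        Example (illustrative): `w = torch.tensor([7.5])` with `embedding_dim=256` yields a `(1, 256)` tensor whose
+        first 128 entries are sines and last 128 are cosines of `1000 * w` at log-spaced frequencies.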
+ """
+ assert len(w.shape) == 1
+ w = w * 1000.0
+
+ half_dim = embedding_dim // 2
+ emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb)
+ emb = w.to(dtype)[:, None] * emb[None, :]
+ emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+ if embedding_dim % 2 == 1: # zero pad
+ emb = torch.nn.functional.pad(emb, (0, 1))
+ assert emb.shape == (w.shape[0], embedding_dim)
+ return emb
+
+ @property
+ def guidance_scale(self):
+ return self._guidance_scale
+
+ @property
+ def clip_skip(self):
+ return self._clip_skip
+
+    # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ @property
+ def do_classifier_free_guidance(self):
+ return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None
+
+ @property
+ def cross_attention_kwargs(self):
+ return self._cross_attention_kwargs
+
+ @property
+ def num_timesteps(self):
+ return self._num_timesteps
+
+ # based on BrushNet: https://github.com/TencentARC/BrushNet/blob/main/src/diffusers/pipelines/brushnet/pipeline_brushnet.py
+ @torch.no_grad()
+ def __call__(
+ self,
+ num_frames: Optional[int] = 24,
+ prompt: Union[str, List[str]] = None,
+        images: PipelineImageInput = None,  # masked input frames
+ masks: PipelineImageInput = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ num_inference_steps: int = 50,
+ timesteps: List[int] = None,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ip_adapter_image: Optional[PipelineImageInput] = None,
+ ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ brushnet_conditioning_scale: Union[float, List[float]] = 1.0,
+ guess_mode: bool = False,
+ control_guidance_start: Union[float, List[float]] = 0.0,
+ control_guidance_end: Union[float, List[float]] = 1.0,
+ clip_skip: Optional[int] = None,
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
+ **kwargs,
+ ):
+ r"""
+ The call function to the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
+            images (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
+                `List[np.ndarray]`, `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+                The masked input frames used as the BrushNet branch condition to guide the `unet` generation.
+            masks (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
+                `List[np.ndarray]`, `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
+                The inpainting masks marking the regions to erase, used as the BrushNet branch condition to guide the `unet` generation.
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ timesteps (`List[int]`, *optional*):
+ Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
+ in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
+ passed will be used. Must be in descending order.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ A higher guidance scale value encourages the model to generate images closely linked to the text
+ `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide what to not include in image generation. If not defined, you need to
+ pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies
+ to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
+ generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor is generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not
+ provided, text embeddings are generated from the `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If
+ not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
+ ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
+ Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters.
+ Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding
+ if `do_classifier_free_guidance` is set to `True`.
+ If not provided, embeddings are computed from the `ip_adapter_image` input argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between `PIL.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that calls every `callback_steps` steps during inference. The function is called with the
+ following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function is called. If not specified, the callback is called at
+ every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
+ [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ brushnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The outputs of the BrushNet are multiplied by `brushnet_conditioning_scale` before they are added
+ to the residual in the original `unet`. If multiple BrushNets are specified in `init`, you can set
+ the corresponding scale as a list.
+ guess_mode (`bool`, *optional*, defaults to `False`):
+ The BrushNet encoder tries to recognize the content of the input image even if you remove all
+ prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended.
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
+ The percentage of total steps at which the BrushNet starts applying.
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
+ The percentage of total steps at which the BrushNet stops applying.
+ clip_skip (`int`, *optional*):
+ Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
+ the output of the pre-final layer will be used for computing the prompt embeddings.
+ callback_on_step_end (`Callable`, *optional*):
+ A function that calls at the end of each denoising steps during the inference. The function is called
+ with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
+ callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
+ `callback_on_step_end_tensor_inputs`.
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
+                `._callback_tensor_inputs` attribute of your pipeline class.
+
+ Examples:
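+            A minimal usage sketch (illustrative only; the pipeline instance `pipe` and the frame/mask loading
+            helper below are assumptions and are not defined in this file):
+
+            ```py
+            >>> frames = load_video_frames("video.mp4")  # hypothetical helper returning a list of PIL images
+            >>> masks = load_video_frames("mask.mp4")    # hypothetical helper returning a list of PIL masks
+            >>> output = pipe(
+            ...     num_frames=24,
+            ...     prompt="clean background",
+            ...     images=frames,
+            ...     masks=masks,
+            ...     num_inference_steps=50,
+            ...     guidance_scale=7.5,
+            ... )
+            >>> video = output.frames
+            ```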
+
+ Returns:
+            [`DiffuEraserPipelineOutput`] or `tuple`:
+                If `return_dict` is `True`, a [`DiffuEraserPipelineOutput`] containing the generated video frames is
+                returned, otherwise a `tuple` is returned whose first element is the generated video.
+ """
+
+ callback = kwargs.pop("callback", None)
+ callback_steps = kwargs.pop("callback_steps", None)
+
+ if callback is not None:
+ deprecate(
+ "callback",
+ "1.0.0",
+ "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+ if callback_steps is not None:
+ deprecate(
+ "callback_steps",
+ "1.0.0",
+ "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`",
+ )
+
+ brushnet = self.brushnet._orig_mod if is_compiled_module(self.brushnet) else self.brushnet
+
+ # align format for control guidance
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+ control_guidance_start, control_guidance_end = (
+ [control_guidance_start],
+ [control_guidance_end],
+ )
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt,
+ images,
+ masks,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ brushnet_conditioning_scale,
+ control_guidance_start,
+ control_guidance_end,
+ callback_on_step_end_tensor_inputs,
+ )
+
+ self._guidance_scale = guidance_scale
+ self._clip_skip = clip_skip
+ self._cross_attention_kwargs = cross_attention_kwargs
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+
+ global_pool_conditions = (
+ brushnet.config.global_pool_conditions
+ if isinstance(brushnet, BrushNetModel)
+ else brushnet.nets[0].config.global_pool_conditions
+ )
+ guess_mode = guess_mode or global_pool_conditions
+ video_length = len(images)
+
+ # 3. Encode input prompt
+ text_encoder_lora_scale = (
+ self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None
+ )
+ prompt_embeds, negative_prompt_embeds = self.encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ lora_scale=text_encoder_lora_scale,
+ clip_skip=self.clip_skip,
+ )
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ if self.do_classifier_free_guidance:
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+ image_embeds = self.prepare_ip_adapter_image_embeds(
+ ip_adapter_image,
+ ip_adapter_image_embeds,
+ device,
+ batch_size * num_images_per_prompt,
+ self.do_classifier_free_guidance,
+ )
+
+ # 4. Prepare image
+ if isinstance(brushnet, BrushNetModel):
+ images = self.prepare_image(
+ images=images,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=brushnet.dtype,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+ original_masks = self.prepare_image(
+ images=masks,
+ width=width,
+ height=height,
+ batch_size=batch_size * num_images_per_prompt,
+ num_images_per_prompt=num_images_per_prompt,
+ device=device,
+ dtype=brushnet.dtype,
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
+ guess_mode=guess_mode,
+ )
+ original_masks_new = []
+ for original_mask in original_masks:
+                original_mask = (original_mask.sum(1)[:, None, :, :] < 0).to(images[0].dtype)
+ original_masks_new.append(original_mask)
+ original_masks = original_masks_new
+
+ height, width = images[0].shape[-2:]
+ else:
+ assert False
+
+ # 5. Prepare timesteps
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+ self._num_timesteps = len(timesteps)
+
+ # 6. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents, noise = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ num_frames,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6.1 prepare condition latents
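+        # The BrushNet condition is built from the VAE latents of the masked frames (4 channels) concatenated with
+        # the mask downsampled to latent resolution (1 channel), matching BrushNet's default `conditioning_channels=5`.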
+ images = torch.cat(images)
+ images = images.to(dtype=images[0].dtype)
+ conditioning_latents = []
+        num = 4  # encode the masked frames through the VAE in small chunks to bound memory usage
+        for i in range(0, images.shape[0], num):
+            conditioning_latents.append(self.vae.encode(images[i : i + num]).latent_dist.sample())
+        conditioning_latents = torch.cat(conditioning_latents, dim=0)
+
+        conditioning_latents = conditioning_latents * self.vae.config.scaling_factor  # shape: (f, c, h, w), c = 4
+
+ original_masks = torch.cat(original_masks)
+        masks = torch.nn.functional.interpolate(
+            original_masks,
+            size=(
+                latents.shape[-2],
+                latents.shape[-1]
+            )
+        )  # shape: (f, 1, h, w)
+
+        conditioning_latents = torch.concat([conditioning_latents, masks], 1)
+
+ # 6.5 Optionally get Guidance Scale Embedding
+ timestep_cond = None
+ if self.unet.config.time_cond_proj_dim is not None:
+ guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt)
+ timestep_cond = self.get_guidance_scale_embedding(
+ guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
+ ).to(device=device, dtype=latents.dtype)
+
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7.1 Add image embeds for IP-Adapter
+ added_cond_kwargs = (
+ {"image_embeds": image_embeds}
+ if ip_adapter_image is not None or ip_adapter_image_embeds is not None
+ else None
+ )
+
+ # 7.2 Create tensor stating which brushnets to keep
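+        # each entry is 1.0 while the current step fraction lies inside [control_guidance_start, control_guidance_end], else 0.0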
+ brushnet_keep = []
+ for i in range(len(timesteps)):
+ keeps = [
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+ for s, e in zip(control_guidance_start, control_guidance_end)
+ ]
+ brushnet_keep.append(keeps[0] if isinstance(brushnet, BrushNetModel) else keeps)
+
+
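+        # Temporal tiling: the video is processed in overlapping windows of `num_frames` frames. `get_frames_context_swap`
+        # returns two window layouts that are alternated between denoising steps (see `i % 2` below), so window
+        # boundaries do not stay fixed and the overlap blending can smooth seams across the whole clip.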
+        overlap = num_frames // 4
+ context_list, context_list_swap = get_frames_context_swap(video_length, overlap=overlap, num_frames_per_clip=num_frames)
+ scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(context_list)
+ scheduler_status_swap = [copy.deepcopy(self.scheduler.__dict__)] * len(context_list_swap)
+ count = torch.zeros_like(latents)
+ value = torch.zeros_like(latents)
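+        # `value` accumulates the per-window latent predictions for the current step and `count` tracks how many
+        # windows wrote to each frame, which is used to locate the overlapping frames that need blending.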
+
+
+ # 8. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ is_unet_compiled = is_compiled_module(self.unet)
+ is_brushnet_compiled = is_compiled_module(self.brushnet)
+ is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1")
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+
+ count.zero_()
+ value.zero_()
+ ## swap
+                if i % 2 == 1:
+ context_list_choose = context_list_swap
+ scheduler_status_choose = scheduler_status_swap
+ else:
+ context_list_choose = context_list
+ scheduler_status_choose = scheduler_status
+
+
+ for j, context in enumerate(context_list_choose):
+ self.scheduler.__dict__.update(scheduler_status_choose[j])
+
+ latents_j = latents[context, :, :, :]
+
+ # Relevant thread:
+ # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428
+ if (is_unet_compiled and is_brushnet_compiled) and is_torch_higher_equal_2_1:
+ torch._inductor.cudagraph_mark_step_begin()
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents_j] * 2) if self.do_classifier_free_guidance else latents_j
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # brushnet(s) inference
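+                    # BrushNet runs as a per-frame 2D model on (b*t) samples, so the prompt embeddings are repeated
+                    # for every frame before being passed to it.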
+ if guess_mode and self.do_classifier_free_guidance:
+ # Infer BrushNet only for the conditional batch.
+ control_model_input = latents_j
+ control_model_input = self.scheduler.scale_model_input(control_model_input, t)
+ brushnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+ brushnet_prompt_embeds = rearrange(repeat(brushnet_prompt_embeds, "b c d -> b t c d", t=num_frames), 'b t c d -> (b t) c d')
+ else:
+ control_model_input = latent_model_input
+ brushnet_prompt_embeds = prompt_embeds
+ if self.do_classifier_free_guidance:
+ neg_brushnet_prompt_embeds, brushnet_prompt_embeds = brushnet_prompt_embeds.chunk(2)
+ brushnet_prompt_embeds = rearrange(repeat(brushnet_prompt_embeds, "b c d -> b t c d", t=num_frames), 'b t c d -> (b t) c d')
+ neg_brushnet_prompt_embeds = rearrange(repeat(neg_brushnet_prompt_embeds, "b c d -> b t c d", t=num_frames), 'b t c d -> (b t) c d')
+ brushnet_prompt_embeds = torch.cat([neg_brushnet_prompt_embeds, brushnet_prompt_embeds])
+ else:
+ brushnet_prompt_embeds = rearrange(repeat(brushnet_prompt_embeds, "b c d -> b t c d", t=num_frames), 'b t c d -> (b t) c d')
+
+ if isinstance(brushnet_keep[i], list):
+ cond_scale = [c * s for c, s in zip(brushnet_conditioning_scale, brushnet_keep[i])]
+ else:
+ brushnet_cond_scale = brushnet_conditioning_scale
+ if isinstance(brushnet_cond_scale, list):
+ brushnet_cond_scale = brushnet_cond_scale[0]
+ cond_scale = brushnet_cond_scale * brushnet_keep[i]
+
+
+ down_block_res_samples, mid_block_res_sample, up_block_res_samples = self.brushnet(
+ control_model_input,
+ t,
+ encoder_hidden_states=brushnet_prompt_embeds,
+ brushnet_cond=torch.cat([conditioning_latents[context, :, :, :]]*2) if self.do_classifier_free_guidance else conditioning_latents[context, :, :, :],
+ conditioning_scale=cond_scale,
+ guess_mode=guess_mode,
+ return_dict=False,
+ )
+
+ if guess_mode and self.do_classifier_free_guidance:
+                        # Inferred BrushNet only for the conditional batch.
+ # To apply the output of BrushNet to both the unconditional and conditional batches,
+ # add 0 to the unconditional batch to keep it unchanged.
+ down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
+ mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
+ up_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in up_block_res_samples]
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ timestep_cond=timestep_cond,
+ cross_attention_kwargs=self.cross_attention_kwargs,
+ down_block_add_samples=down_block_res_samples,
+ mid_block_add_sample=mid_block_res_sample,
+ up_block_add_samples=up_block_res_samples,
+ added_cond_kwargs=added_cond_kwargs,
+ return_dict=False,
+ num_frames=num_frames,
+ )[0]
+
+ # perform guidance
+ if self.do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents_j = self.scheduler.step(noise_pred, t, latents_j, **extra_step_kwargs, return_dict=False)[0]
+
+ count[context, ...] += 1
+
+                    if j == 0:
+                        value[context, ...] += latents_j
+                    else:
+                        # blend the frames shared with the previous window, then copy the remaining frames once
+                        overlap_index_list = [index for index, cnt in enumerate(count[context, 0, 0, 0]) if cnt > 1]
+                        overlap_cur = len(overlap_index_list)
+                        ratio_next = torch.linspace(0, 1, overlap_cur + 2)[1:-1]
+                        ratio_pre = 1 - ratio_next
+                        for i_overlap in overlap_index_list:
+                            value[context[i_overlap], ...] = value[context[i_overlap], ...] * ratio_pre[i_overlap] + latents_j[i_overlap, ...] * ratio_next[i_overlap]
+                        value[context[overlap_cur:num_frames], ...] = latents_j[overlap_cur:num_frames, ...]
+
+ latents = value.clone()
+
+ if callback_on_step_end is not None:
+ callback_kwargs = {}
+ for k in callback_on_step_end_tensor_inputs:
+ callback_kwargs[k] = locals()[k]
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+ latents = callback_outputs.pop("latents", latents)
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+ negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ step_idx = i // getattr(self.scheduler, "order", 1)
+ callback(step_idx, t, latents)
+
+
+ # If we do sequential model offloading, let's offload unet and brushnet
+ # manually for max memory savings
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.unet.to("cpu")
+ self.brushnet.to("cpu")
+ torch.cuda.empty_cache()
+
+ if output_type == "latent":
+ image = latents
+ has_nsfw_concept = None
+ return DiffuEraserPipelineOutput(frames=image, nsfw_content_detected=has_nsfw_concept)
+
+ video_tensor = self.decode_latents(latents, weight_dtype=prompt_embeds.dtype)
+
+ if output_type == "pt":
+ video = video_tensor
+ else:
+ video = self.image_processor.postprocess(video_tensor, output_type=output_type)
+
+ # Offload all models
+ self.maybe_free_model_hooks()
+
+        if not return_dict:
+            # the safety checker is not run on the decoded video, so there are no nsfw flags to return here
+            return (video, None)
+
+ return DiffuEraserPipelineOutput(frames=video, latents=latents)
diff --git a/examples/example1/mask.mp4 b/examples/example1/mask.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..065f632cfe456ba808fa1c63fb7285fb52aef8b5
Binary files /dev/null and b/examples/example1/mask.mp4 differ
diff --git a/examples/example1/video.mp4 b/examples/example1/video.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..c9c3acec1b37146bb9379de7df714fe013f4806b
Binary files /dev/null and b/examples/example1/video.mp4 differ
diff --git a/examples/example2/mask.mp4 b/examples/example2/mask.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..463ccf51075591edd7e2d9f8b528c29d786b63d1
--- /dev/null
+++ b/examples/example2/mask.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:39849531b31960ee023cd33caf402afd4a4c1402276ba8afa04b7888feb52c3f
+size 1249680
diff --git a/examples/example2/video.mp4 b/examples/example2/video.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..992b5fc234a255fc104f58f7bded5455b422f171
Binary files /dev/null and b/examples/example2/video.mp4 differ
diff --git a/examples/example3/mask.mp4 b/examples/example3/mask.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..fe2ee61cb7930d81bc10baa3bb6d5983ae8940e9
Binary files /dev/null and b/examples/example3/mask.mp4 differ
diff --git a/examples/example3/video.mp4 b/examples/example3/video.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..4671b8fca23896f621d3f0426e18073392254e35
--- /dev/null
+++ b/examples/example3/video.mp4
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:b21c936a305f80ed6707bad621712b24bd1e7a69f82ec7cdd949b18fd1a7fd56
+size 5657081
diff --git a/libs/brushnet_CA.py b/libs/brushnet_CA.py
new file mode 100644
index 0000000000000000000000000000000000000000..cab2b6bca90530697607109d26491378862cde32
--- /dev/null
+++ b/libs/brushnet_CA.py
@@ -0,0 +1,939 @@
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from torch.nn import functional as F
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.utils import BaseOutput, logging
+from diffusers.models.attention_processor import (
+ ADDED_KV_ATTENTION_PROCESSORS,
+ CROSS_ATTENTION_PROCESSORS,
+ AttentionProcessor,
+ AttnAddedKVProcessor,
+ AttnProcessor,
+)
+from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps
+from diffusers.models.modeling_utils import ModelMixin
+from .unet_2d_blocks import (
+ CrossAttnDownBlock2D,
+ DownBlock2D,
+ UNetMidBlock2D,
+ UNetMidBlock2DCrossAttn,
+ get_down_block,
+ get_mid_block,
+ get_up_block,
+ MidBlock2D
+)
+
+# from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel
+from libs.unet_2d_condition import UNet2DConditionModel
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class BrushNetOutput(BaseOutput):
+ """
+ The output of [`BrushNetModel`].
+
+ Args:
+ up_block_res_samples (`tuple[torch.Tensor]`):
+ A tuple of upsample activations at different resolutions for each upsampling block. Each tensor should
+ be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
+ used to condition the original UNet's upsampling activations.
+ down_block_res_samples (`tuple[torch.Tensor]`):
+ A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should
+ be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be
+ used to condition the original UNet's downsampling activations.
+        mid_block_res_sample (`torch.Tensor`):
+            The activation of the middle block (the lowest sample resolution). The tensor should be of shape
+ `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`.
+ Output can be used to condition the original UNet's middle block activation.
+ """
+
+ up_block_res_samples: Tuple[torch.Tensor]
+ down_block_res_samples: Tuple[torch.Tensor]
+ mid_block_res_sample: torch.Tensor
+
+
+class BrushNetModel(ModelMixin, ConfigMixin):
+ """
+ A BrushNet model.
+
+ Args:
+ in_channels (`int`, defaults to 4):
+ The number of channels in the input sample.
+ flip_sin_to_cos (`bool`, defaults to `True`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, defaults to 0):
+ The frequency shift to apply to the time embedding.
+ down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use.
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
+ The tuple of upsample blocks to use.
+ only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`):
+ block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, defaults to 2):
+ The number of layers per block.
+ downsample_padding (`int`, defaults to 1):
+ The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, defaults to 1):
+ The scale factor to use for the mid block.
+ act_fn (`str`, defaults to "silu"):
+ The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups to use for the normalization. If None, normalization and activation layers is skipped
+ in post-processing.
+ norm_eps (`float`, defaults to 1e-5):
+ The epsilon to use for the normalization.
+ cross_attention_dim (`int`, defaults to 1280):
+ The dimension of the cross attention features.
+ transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ encoder_hid_dim (`int`, *optional*, defaults to None):
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+ dimension to `cross_attention_dim`.
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
+ embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`.
+ attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8):
+ The dimension of the attention heads.
+ use_linear_projection (`bool`, defaults to `False`):
+ class_embed_type (`str`, *optional*, defaults to `None`):
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None,
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+ addition_embed_type (`str`, *optional*, defaults to `None`):
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+ "text". "text" will use the `TextTimeEmbedding` layer.
+ num_class_embeds (`int`, *optional*, defaults to 0):
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
+ class conditioning with `class_embed_type` equal to `None`.
+ upcast_attention (`bool`, defaults to `False`):
+ resnet_time_scale_shift (`str`, defaults to `"default"`):
+ Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`.
+ projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`):
+ The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
+ `class_embed_type="projection"`.
+ brushnet_conditioning_channel_order (`str`, defaults to `"rgb"`):
+ The channel order of conditional image. Will convert to `rgb` if it's `bgr`.
+ conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`):
+ The tuple of output channel for each block in the `conditioning_embedding` layer.
+ global_pool_conditions (`bool`, defaults to `False`):
+ TODO(Patrick) - unused parameter.
+ addition_embed_type_num_heads (`int`, defaults to 64):
+ The number of heads to use for the `TextTimeEmbedding` layer.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 4,
+ conditioning_channels: int = 5,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str, ...] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
+ up_block_types: Tuple[str, ...] = (
+ "UpBlock2D",
+ "CrossAttnUpBlock2D",
+ "CrossAttnUpBlock2D",
+ "CrossAttnUpBlock2D",
+ ),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
+ layers_per_block: int = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: int = 1280,
+ transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1,
+ encoder_hid_dim: Optional[int] = None,
+ encoder_hid_dim_type: Optional[str] = None,
+ attention_head_dim: Union[int, Tuple[int, ...]] = 8,
+ num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ addition_embed_type: Optional[str] = None,
+ addition_time_embed_dim: Optional[int] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ projection_class_embeddings_input_dim: Optional[int] = None,
+ brushnet_conditioning_channel_order: str = "rgb",
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
+ global_pool_conditions: bool = False,
+ addition_embed_type_num_heads: int = 64,
+ ):
+ super().__init__()
+
+ # If `num_attention_heads` is not defined (which is the case for most models)
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+ # which is why we correct for the naming here.
+ num_attention_heads = num_attention_heads or attention_head_dim
+
+ # Check inputs
+ if len(down_block_types) != len(up_block_types):
+ raise ValueError(
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+ )
+
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+ # input
+ conv_in_kernel = 3
+ conv_in_padding = (conv_in_kernel - 1) // 2
+ self.conv_in_condition = nn.Conv2d(
+ in_channels+conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+ )
+
+ # time
+ time_embed_dim = block_out_channels[0] * 4
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+ self.time_embedding = TimestepEmbedding(
+ timestep_input_dim,
+ time_embed_dim,
+ act_fn=act_fn,
+ )
+
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
+ encoder_hid_dim_type = "text_proj"
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
+ )
+
+ if encoder_hid_dim_type == "text_proj":
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+ elif encoder_hid_dim_type == "text_image_proj":
+ # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+            # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
+ self.encoder_hid_proj = TextImageProjection(
+ text_embed_dim=encoder_hid_dim,
+ image_embed_dim=cross_attention_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+
+ elif encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
+ )
+ else:
+ self.encoder_hid_proj = None
+
+ # class embedding
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ elif class_embed_type == "projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
+ # 2. it projects from an arbitrary input dimension.
+ #
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ else:
+ self.class_embedding = None
+
+ if addition_embed_type == "text":
+ if encoder_hid_dim is not None:
+ text_time_embedding_from_dim = encoder_hid_dim
+ else:
+ text_time_embedding_from_dim = cross_attention_dim
+
+ self.add_embedding = TextTimeEmbedding(
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+ )
+ elif addition_embed_type == "text_image":
+ # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
+ # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
+            # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
+ self.add_embedding = TextImageTimeEmbedding(
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+ )
+ elif addition_embed_type == "text_time":
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+
+ elif addition_embed_type is not None:
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
+
+ self.down_blocks = nn.ModuleList([])
+ self.brushnet_down_blocks = nn.ModuleList([])
+
+ if isinstance(only_cross_attention, bool):
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ # down
+ output_channel = block_out_channels[0]
+
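+        # ControlNet-style zero-initialized 1x1 convolutions: they output zeros at the start of training, so the
+        # BrushNet branch initially leaves the base UNet untouched and its contribution is learned gradually.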
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ brushnet_block = zero_module(brushnet_block)
+        self.brushnet_down_blocks.append(brushnet_block)  # zero convolution
+
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads[i],
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ downsample_padding=downsample_padding,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ self.down_blocks.append(down_block)
+
+ for _ in range(layers_per_block):
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ brushnet_block = zero_module(brushnet_block)
+                self.brushnet_down_blocks.append(brushnet_block)  # zero convolution
+
+ if not is_final_block:
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ brushnet_block = zero_module(brushnet_block)
+ self.brushnet_down_blocks.append(brushnet_block)
+
+ # mid
+ mid_block_channel = block_out_channels[-1]
+
+ brushnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1)
+ brushnet_block = zero_module(brushnet_block)
+ self.brushnet_mid_block = brushnet_block
+
+ self.mid_block = get_mid_block(
+ mid_block_type,
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ in_channels=mid_block_channel,
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads[-1],
+ resnet_groups=norm_num_groups,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ )
+
+ # count how many layers upsample the images
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
+ reversed_transformer_layers_per_block = (list(reversed(transformer_layers_per_block)))
+ only_cross_attention = list(reversed(only_cross_attention))
+
+ output_channel = reversed_block_out_channels[0]
+
+ self.up_blocks = nn.ModuleList([])
+ self.brushnet_up_blocks = nn.ModuleList([])
+
+ for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=layers_per_block+1,
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resolution_idx=i,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=reversed_num_attention_heads[i],
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ for _ in range(layers_per_block+1):
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ brushnet_block = zero_module(brushnet_block)
+ self.brushnet_up_blocks.append(brushnet_block)
+
+ if not is_final_block:
+ brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1)
+ brushnet_block = zero_module(brushnet_block)
+ self.brushnet_up_blocks.append(brushnet_block)
+
+
+ @classmethod
+ def from_unet(
+ cls,
+ unet: UNet2DConditionModel,
+ brushnet_conditioning_channel_order: str = "rgb",
+ conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256),
+ load_weights_from_unet: bool = True,
+ conditioning_channels: int = 5,
+ ):
+ r"""
+ Instantiate a [`BrushNetModel`] from [`UNet2DConditionModel`].
+
+ Parameters:
+ unet (`UNet2DConditionModel`):
+ The UNet model weights to copy to the [`BrushNetModel`]. All configuration options are also copied
+ where applicable.
+ """
+ transformer_layers_per_block = (
+ unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1
+ )
+ encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None
+ encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None
+ addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None
+ addition_time_embed_dim = (
+ unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None
+ )
+
+ brushnet = cls(
+ in_channels=unet.config.in_channels,
+ conditioning_channels=conditioning_channels,
+ flip_sin_to_cos=unet.config.flip_sin_to_cos,
+ freq_shift=unet.config.freq_shift,
+ # down_block_types=['DownBlock2D','DownBlock2D','DownBlock2D','DownBlock2D'],
+ down_block_types=[
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ],
+ # mid_block_type='MidBlock2D',
+ mid_block_type="UNetMidBlock2DCrossAttn",
+ # up_block_types=['UpBlock2D','UpBlock2D','UpBlock2D','UpBlock2D'],
+ up_block_types=["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"],
+ only_cross_attention=unet.config.only_cross_attention,
+ block_out_channels=unet.config.block_out_channels,
+ layers_per_block=unet.config.layers_per_block,
+ downsample_padding=unet.config.downsample_padding,
+ mid_block_scale_factor=unet.config.mid_block_scale_factor,
+ act_fn=unet.config.act_fn,
+ norm_num_groups=unet.config.norm_num_groups,
+ norm_eps=unet.config.norm_eps,
+ cross_attention_dim=unet.config.cross_attention_dim,
+ transformer_layers_per_block=transformer_layers_per_block,
+ encoder_hid_dim=encoder_hid_dim,
+ encoder_hid_dim_type=encoder_hid_dim_type,
+ attention_head_dim=unet.config.attention_head_dim,
+ num_attention_heads=unet.config.num_attention_heads,
+ use_linear_projection=unet.config.use_linear_projection,
+ class_embed_type=unet.config.class_embed_type,
+ addition_embed_type=addition_embed_type,
+ addition_time_embed_dim=addition_time_embed_dim,
+ num_class_embeds=unet.config.num_class_embeds,
+ upcast_attention=unet.config.upcast_attention,
+ resnet_time_scale_shift=unet.config.resnet_time_scale_shift,
+ projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim,
+ brushnet_conditioning_channel_order=brushnet_conditioning_channel_order,
+ conditioning_embedding_out_channels=conditioning_embedding_out_channels,
+ )
+
+ if load_weights_from_unet:
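+            # The condition convolution sees [noisy latents | masked-image latents | mask] (4 + 4 + 1 channels with
+            # the default `conditioning_channels=5`), so the pretrained 4-channel `conv_in` weights are copied into
+            # both latent slots while the mask slot keeps its zero initialization.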
+            conv_in_condition_weight = torch.zeros_like(brushnet.conv_in_condition.weight)
+            conv_in_condition_weight[:, :4, ...] = unet.conv_in.weight
+            conv_in_condition_weight[:, 4:8, ...] = unet.conv_in.weight
+            brushnet.conv_in_condition.weight = torch.nn.Parameter(conv_in_condition_weight)
+            brushnet.conv_in_condition.bias = unet.conv_in.bias
+
+ brushnet.time_proj.load_state_dict(unet.time_proj.state_dict())
+ brushnet.time_embedding.load_state_dict(unet.time_embedding.state_dict())
+
+ if brushnet.class_embedding:
+ brushnet.class_embedding.load_state_dict(unet.class_embedding.state_dict())
+
+            brushnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(), strict=False)
+            brushnet.mid_block.load_state_dict(unet.mid_block.state_dict(), strict=False)
+            brushnet.up_blocks.load_state_dict(unet.up_blocks.state_dict(), strict=False)
+
+ return brushnet
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model,
+            indexed by their weight names.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
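+        Example (an illustrative sketch; `brushnet` stands for an instantiated [`BrushNetModel`] and the stock
+        `AttnProcessor` is reused):
+
+        >>> from diffusers.models.attention_processor import AttnProcessor
+        >>> brushnet.set_attn_processor(AttnProcessor())  # same processor for every attention layer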
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnAddedKVProcessor()
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None:
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
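+
+        Example (illustrative; `brushnet` stands for any instantiated [`BrushNetModel`]):
+
+        >>> brushnet.set_attention_slice("auto")  # halve the attention head size in every sliceable attention layer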
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_sliceable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_sliceable_dims(module)
+
+ num_sliceable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_sliceable_layers * [1]
+
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
+ if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)):
+ module.gradient_checkpointing = value
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ brushnet_cond: torch.FloatTensor,
+ conditioning_scale: float = 1.0,
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ guess_mode: bool = False,
+ return_dict: bool = True,
+ ) -> Union[BrushNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]:
+ """
+ The [`BrushNetModel`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The noisy input tensor.
+ timestep (`Union[torch.Tensor, float, int]`):
+ The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.Tensor`):
+ The encoder hidden states.
+ brushnet_cond (`torch.FloatTensor`):
+                The conditional input tensor (masked image latents concatenated with the mask) of shape
+                `(batch_size, conditioning_channels, height, width)`.
+ conditioning_scale (`float`, defaults to `1.0`):
+ The scale factor for BrushNet outputs.
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
+ timestep_cond (`torch.Tensor`, *optional*, defaults to `None`):
+ Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the
+ timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep
+ embeddings.
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ added_cond_kwargs (`dict`):
+ Additional conditions for the Stable Diffusion XL UNet.
+ cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`):
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor`.
+ guess_mode (`bool`, defaults to `False`):
+                In this mode, the BrushNet encoder tries its best to recognize the content of the input even if you
+                remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
+ return_dict (`bool`, defaults to `True`):
+ Whether or not to return a [`~models.brushnet.BrushNetOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.brushnet.BrushNetOutput`] **or** `tuple`:
+ If `return_dict` is `True`, a [`~models.brushnet.BrushNetOutput`] is returned, otherwise a tuple is
+ returned where the first element is the sample tensor.
+ """
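+        # Overall flow: concatenate the noisy latents with the BrushNet condition, run them through the copied UNet
+        # down/mid/up blocks, and project every intermediate feature map through a zero-initialized 1x1 convolution
+        # so the residuals returned for the denoising UNet start out at zero.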
+ # check channel order
+ channel_order = self.config.brushnet_conditioning_channel_order
+
+ if channel_order == "rgb":
+ # in rgb order by default
+ ...
+ elif channel_order == "bgr":
+ brushnet_cond = torch.flip(brushnet_cond, dims=[1])
+ else:
+ raise ValueError(f"unknown `brushnet_conditioning_channel_order`: {channel_order}")
+
+ # prepare attention_mask
+ if attention_mask is not None:
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+
+ emb = self.time_embedding(t_emb, timestep_cond)
+ aug_emb = None
+
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
+ emb = emb + class_emb
+
+ if self.config.addition_embed_type is not None:
+ if self.config.addition_embed_type == "text":
+ aug_emb = self.add_embedding(encoder_hidden_states)
+
+ elif self.config.addition_embed_type == "text_time":
+ if "text_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+ )
+ text_embeds = added_cond_kwargs.get("text_embeds")
+ if "time_ids" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+ )
+ time_ids = added_cond_kwargs.get("time_ids")
+ time_embeds = self.add_time_proj(time_ids.flatten())
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(add_embeds)
+
+ emb = emb + aug_emb if aug_emb is not None else emb
+
+ # 2. pre-process
+        brushnet_cond = torch.concat([sample, brushnet_cond], dim=1)
+ sample = self.conv_in_condition(brushnet_cond)
+
+
+ # 3. down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ down_block_res_samples += res_samples
+
+        # 4. BrushNet down blocks
+ brushnet_down_block_res_samples = ()
+ for down_block_res_sample, brushnet_down_block in zip(down_block_res_samples, self.brushnet_down_blocks):
+ down_block_res_sample = brushnet_down_block(down_block_res_sample)
+ brushnet_down_block_res_samples = brushnet_down_block_res_samples + (down_block_res_sample,)
+
+ # 5. mid
+ if self.mid_block is not None:
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ )
+ else:
+ sample = self.mid_block(sample, emb)
+
+ # 6. BrushNet mid blocks
+ brushnet_mid_block_res_sample = self.brushnet_mid_block(sample)
+
+
+        # 7. up
+        up_block_res_samples = ()
+        upsample_size = None
+        for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample, up_res_samples = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ return_res_samples=True
+ )
+ else:
+ sample, up_res_samples = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ upsample_size=upsample_size,
+ return_res_samples=True
+ )
+
+ up_block_res_samples += up_res_samples
+
+ # 8. BrushNet up blocks
+ brushnet_up_block_res_samples = ()
+ for up_block_res_sample, brushnet_up_block in zip(up_block_res_samples, self.brushnet_up_blocks):
+ up_block_res_sample = brushnet_up_block(up_block_res_sample)
+ brushnet_up_block_res_samples = brushnet_up_block_res_samples + (up_block_res_sample,)
+
+        # 9. scaling
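+        # Guess mode ramps the residual scales geometrically from 0.1 (first down-block output) to 1.0 (last
+        # up-block output) before applying `conditioning_scale`; otherwise a single uniform scale is used.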
+ if guess_mode and not self.config.global_pool_conditions:
+ scales = torch.logspace(-1, 0, len(brushnet_down_block_res_samples) + 1 + len(brushnet_up_block_res_samples), device=sample.device) # 0.1 to 1.0
+ scales = scales * conditioning_scale
+
+ brushnet_down_block_res_samples = [sample * scale for sample, scale in zip(brushnet_down_block_res_samples, scales[:len(brushnet_down_block_res_samples)])]
+ brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * scales[len(brushnet_down_block_res_samples)]
+ brushnet_up_block_res_samples = [sample * scale for sample, scale in zip(brushnet_up_block_res_samples, scales[len(brushnet_down_block_res_samples)+1:])]
+ else:
+ brushnet_down_block_res_samples = [sample * conditioning_scale for sample in brushnet_down_block_res_samples]
+ brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * conditioning_scale
+ brushnet_up_block_res_samples = [sample * conditioning_scale for sample in brushnet_up_block_res_samples]
+
+
+ if self.config.global_pool_conditions:
+ brushnet_down_block_res_samples = [
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_down_block_res_samples
+ ]
+ brushnet_mid_block_res_sample = torch.mean(brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True)
+ brushnet_up_block_res_samples = [
+ torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_up_block_res_samples
+ ]
+
+ if not return_dict:
+ return (brushnet_down_block_res_samples, brushnet_mid_block_res_sample, brushnet_up_block_res_samples)
+
+ return BrushNetOutput(
+ down_block_res_samples=brushnet_down_block_res_samples,
+ mid_block_res_sample=brushnet_mid_block_res_sample,
+ up_block_res_samples=brushnet_up_block_res_samples
+ )
+
+
+def zero_module(module):
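+    # Zero-initialize every parameter of `module` in place and return it, so the BrushNet 1x1 projection
+    # convolutions start as no-ops and only learn a residual during training.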
+ for p in module.parameters():
+ nn.init.zeros_(p)
+ return module
diff --git a/libs/transformer_temporal.py b/libs/transformer_temporal.py
new file mode 100644
index 0000000000000000000000000000000000000000..32928afbb87f903e3b5b75507455b5876cc7a7bd
--- /dev/null
+++ b/libs/transformer_temporal.py
@@ -0,0 +1,375 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Any, Dict, Optional
+
+import torch
+from torch import nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.utils import BaseOutput
+from diffusers.models.attention import BasicTransformerBlock, TemporalBasicTransformerBlock
+from diffusers.models.embeddings import TimestepEmbedding, Timesteps
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.resnet import AlphaBlender
+
+
+@dataclass
+class TransformerTemporalModelOutput(BaseOutput):
+ """
+ The output of [`TransformerTemporalModel`].
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`):
+ The hidden states output conditioned on `encoder_hidden_states` input.
+ """
+
+ sample: torch.FloatTensor
+
+
+class TransformerTemporalModel(ModelMixin, ConfigMixin):
+ """
+ A Transformer model for video-like data.
+
+ Parameters:
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+ in_channels (`int`, *optional*):
+ The number of channels in the input and output (specify if the input is **continuous**).
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+ attention_bias (`bool`, *optional*):
+ Configure if the `TransformerBlock` attention should contain a bias parameter.
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
+ This is fixed during training since it is used to learn a number of position embeddings.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`):
+ Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported
+ activation functions.
+ norm_elementwise_affine (`bool`, *optional*):
+ Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization.
+ double_self_attention (`bool`, *optional*):
+ Configure if each `TransformerBlock` should contain two self-attention layers.
+ positional_embeddings: (`str`, *optional*):
+            The type of positional embeddings to apply to the sequence input before passing it to the transformer
+            blocks.
+ num_positional_embeddings: (`int`, *optional*):
+ The maximum length of the sequence over which to apply positional embeddings.
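+
+    Example (an illustrative instantiation; the sizes below are arbitrary):
+
+    >>> temporal_transformer = TransformerTemporalModel(
+    ...     num_attention_heads=8, attention_head_dim=40, in_channels=320, num_layers=1
+    ... )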
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 88,
+ in_channels: Optional[int] = None,
+ out_channels: Optional[int] = None,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ norm_num_groups: int = 32,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ sample_size: Optional[int] = None,
+ activation_fn: str = "geglu",
+ norm_elementwise_affine: bool = True,
+ double_self_attention: bool = True,
+ positional_embeddings: Optional[str] = None,
+ num_positional_embeddings: Optional[int] = None,
+ ):
+ super().__init__()
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ inner_dim = num_attention_heads * attention_head_dim
+
+ self.in_channels = in_channels
+
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+
+ # 3. Define transformers blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ attention_bias=attention_bias,
+ double_self_attention=double_self_attention,
+ norm_elementwise_affine=norm_elementwise_affine,
+ positional_embeddings=positional_embeddings,
+ num_positional_embeddings=num_positional_embeddings,
+ )
+ for d in range(num_layers)
+ ]
+ )
+
+ self.proj_out = nn.Linear(inner_dim, in_channels)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ timestep: Optional[torch.LongTensor] = None,
+ num_frames: int = 1,
+ encoder_hidden_states: Optional[torch.LongTensor] = None,
+ class_labels: torch.LongTensor = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> torch.FloatTensor:
+ """
+ The [`TransformerTemporal`] forward method.
+
+ Args:
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
+ Input hidden_states.
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+ self-attention.
+ timestep ( `torch.LongTensor`, *optional*):
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
+                `AdaLayerNormZero`.
+ num_frames (`int`, *optional*, defaults to 1):
+ The number of frames to be processed per batch. This is used to reshape the hidden states.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+
+        Returns:
+            `torch.FloatTensor`:
+                The transformed hidden states of shape `(batch_size * num_frames, channel, height, width)`, with the
+                temporal residual added.
+ """
+ # 1. Input
+ batch_frames, channel, height, width = hidden_states.shape
+ batch_size = batch_frames // num_frames
+
+ residual = hidden_states
+
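+        # reshape (batch * frames, C, H, W) -> (batch * H * W, frames, C) so that self-attention runs over the
+        # frame axis independently at every spatial location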
+ hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width)
+ hidden_states = hidden_states.permute(0, 2, 1, 3, 4)
+
+ hidden_states = self.norm(hidden_states)
+ hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)
+
+ hidden_states = self.proj_in(hidden_states)
+
+ # 2. Blocks
+ for block in self.transformer_blocks:
+ hidden_states = block(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ timestep=timestep,
+ cross_attention_kwargs=cross_attention_kwargs,
+ class_labels=class_labels,
+ )
+
+ # 3. Output
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = (
+ hidden_states[None, None, :]
+ .reshape(batch_size, height, width, num_frames, channel)
+ .permute(0, 3, 4, 1, 2)
+ .contiguous()
+ )
+ hidden_states = hidden_states.reshape(batch_frames, channel, height, width)
+
+ output = hidden_states + residual
+
+ return output
+
+
+class TransformerSpatioTemporalModel(nn.Module):
+ """
+ A Transformer model for video-like data.
+
+ Parameters:
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+ in_channels (`int`, *optional*):
+ The number of channels in the input and output (specify if the input is **continuous**).
+ out_channels (`int`, *optional*):
+ The number of channels in the output (specify if the input is **continuous**).
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+ """
+
+ def __init__(
+ self,
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 88,
+ in_channels: int = 320,
+ out_channels: Optional[int] = None,
+ num_layers: int = 1,
+ cross_attention_dim: Optional[int] = None,
+ ):
+ super().__init__()
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+
+ inner_dim = num_attention_heads * attention_head_dim
+ self.inner_dim = inner_dim
+
+ # 2. Define input layers
+ self.in_channels = in_channels
+ self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6)
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+
+ # 3. Define transformers blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ for d in range(num_layers)
+ ]
+ )
+
+ time_mix_inner_dim = inner_dim
+ self.temporal_transformer_blocks = nn.ModuleList(
+ [
+ TemporalBasicTransformerBlock(
+ inner_dim,
+ time_mix_inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ for _ in range(num_layers)
+ ]
+ )
+
+ time_embed_dim = in_channels * 4
+ self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels)
+ self.time_proj = Timesteps(in_channels, True, 0)
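+        # AlphaBlender learns how strongly the temporal branch is mixed back into the spatial features;
+        # `image_only_indicator` lets image-only entries fall back to the spatial path.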
+ self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")
+
+ # 4. Define output layers
+ self.out_channels = in_channels if out_channels is None else out_channels
+ # TODO: should use out_channels for continuous projections
+ self.proj_out = nn.Linear(inner_dim, in_channels)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ image_only_indicator: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ):
+ """
+ Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch size * num_frames, channel, height, width)`):
+                Input hidden_states. The number of frames is inferred from `image_only_indicator`.
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+ self-attention.
+ image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*):
+ A tensor indicating whether the input contains only images. 1 indicates that the input contains only
+ images, 0 indicates that the input contains video frames.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`:
+ If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is
+ returned, otherwise a `tuple` where the first element is the sample tensor.
+ """
+ # 1. Input
+ batch_frames, _, height, width = hidden_states.shape
+ num_frames = image_only_indicator.shape[-1]
+ batch_size = batch_frames // num_frames
+
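+        # Temporal cross-attention conditions only on the encoder states of the first frame of each clip, broadcast
+        # to every spatial location, so all frames of a video share the same conditioning.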
+ time_context = encoder_hidden_states
+ time_context_first_timestep = time_context[None, :].reshape(
+ batch_size, num_frames, -1, time_context.shape[-1]
+ )[:, 0]
+ time_context = time_context_first_timestep[None, :].broadcast_to(
+ height * width, batch_size, 1, time_context.shape[-1]
+ )
+ time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1])
+
+ residual = hidden_states
+
+ hidden_states = self.norm(hidden_states)
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim)
+ hidden_states = self.proj_in(hidden_states)
+
+ num_frames_emb = torch.arange(num_frames, device=hidden_states.device)
+ num_frames_emb = num_frames_emb.repeat(batch_size, 1)
+ num_frames_emb = num_frames_emb.reshape(-1)
+ t_emb = self.time_proj(num_frames_emb)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=hidden_states.dtype)
+
+ emb = self.time_pos_embed(t_emb)
+ emb = emb[:, None, :]
+
+ # 2. Blocks
+ for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks):
+ if self.training and self.gradient_checkpointing:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ block,
+ hidden_states,
+ None,
+ encoder_hidden_states,
+ None,
+ use_reentrant=False,
+ )
+ else:
+ hidden_states = block(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ )
+
+ hidden_states_mix = hidden_states
+ hidden_states_mix = hidden_states_mix + emb
+
+ hidden_states_mix = temporal_block(
+ hidden_states_mix,
+ num_frames=num_frames,
+ encoder_hidden_states=time_context,
+ )
+ hidden_states = self.time_mixer(
+ x_spatial=hidden_states,
+ x_temporal=hidden_states_mix,
+ image_only_indicator=image_only_indicator,
+ )
+
+ # 3. Output
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+
+ output = hidden_states + residual
+
+ if not return_dict:
+ return (output,)
+
+ return TransformerTemporalModelOutput(sample=output)
diff --git a/libs/unet_2d_blocks.py b/libs/unet_2d_blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b8e6f51f3017b0fb8bd3f59cb05eb23c5f09a80
--- /dev/null
+++ b/libs/unet_2d_blocks.py
@@ -0,0 +1,3824 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Dict, Optional, Tuple, Union
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.utils import is_torch_version, logging
+from diffusers.utils.torch_utils import apply_freeu
+from diffusers.models.activations import get_activation
+from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0
+from diffusers.models.normalization import AdaGroupNorm
+from diffusers.models.resnet import (
+ Downsample2D,
+ FirDownsample2D,
+ FirUpsample2D,
+ KDownsample2D,
+ KUpsample2D,
+ ResnetBlock2D,
+ ResnetBlockCondNorm2D,
+ Upsample2D,
+)
+from diffusers.models.transformers.dual_transformer_2d import DualTransformer2DModel
+from diffusers.models.transformers.transformer_2d import Transformer2DModel
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def get_down_block(
+ down_block_type: str,
+ num_layers: int,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ add_downsample: bool,
+ resnet_eps: float,
+ resnet_act_fn: str,
+ transformer_layers_per_block: int = 1,
+ num_attention_heads: Optional[int] = None,
+ resnet_groups: Optional[int] = None,
+ cross_attention_dim: Optional[int] = None,
+ downsample_padding: Optional[int] = None,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ attention_type: str = "default",
+ resnet_skip_time_act: bool = False,
+ resnet_out_scale_factor: float = 1.0,
+ cross_attention_norm: Optional[str] = None,
+ attention_head_dim: Optional[int] = None,
+ downsample_type: Optional[str] = None,
+ dropout: float = 0.0,
+):
+ # If attn head dim is not defined, we default it to the number of heads
+ if attention_head_dim is None:
+        logger.warning(
+ f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
+ )
+ attention_head_dim = num_attention_heads
+
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
+ if down_block_type == "DownBlock2D":
+ return DownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "ResnetDownsampleBlock2D":
+ return ResnetDownsampleBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ skip_time_act=resnet_skip_time_act,
+ output_scale_factor=resnet_out_scale_factor,
+ )
+ elif down_block_type == "AttnDownBlock2D":
+ if add_downsample is False:
+ downsample_type = None
+ else:
+ downsample_type = downsample_type or "conv" # default to 'conv'
+ return AttnDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ attention_head_dim=attention_head_dim,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ downsample_type=downsample_type,
+ )
+ elif down_block_type == "CrossAttnDownBlock2D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D")
+ return CrossAttnDownBlock2D(
+ num_layers=num_layers,
+ transformer_layers_per_block=transformer_layers_per_block,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_type=attention_type,
+ )
+ elif down_block_type == "SimpleCrossAttnDownBlock2D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D")
+ return SimpleCrossAttnDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ cross_attention_dim=cross_attention_dim,
+ attention_head_dim=attention_head_dim,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ skip_time_act=resnet_skip_time_act,
+ output_scale_factor=resnet_out_scale_factor,
+ only_cross_attention=only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ )
+ elif down_block_type == "SkipDownBlock2D":
+ return SkipDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ downsample_padding=downsample_padding,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "AttnSkipDownBlock2D":
+ return AttnSkipDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ attention_head_dim=attention_head_dim,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "DownEncoderBlock2D":
+ return DownEncoderBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ dropout=dropout,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "AttnDownEncoderBlock2D":
+ return AttnDownEncoderBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ dropout=dropout,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ attention_head_dim=attention_head_dim,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "KDownBlock2D":
+ return KDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ )
+ elif down_block_type == "KCrossAttnDownBlock2D":
+ return KCrossAttnDownBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ cross_attention_dim=cross_attention_dim,
+ attention_head_dim=attention_head_dim,
+            add_self_attention=not add_downsample,
+ )
+ raise ValueError(f"{down_block_type} does not exist.")
+
+
+def get_mid_block(
+ mid_block_type: str,
+ temb_channels: int,
+ in_channels: int,
+ resnet_eps: float,
+ resnet_act_fn: str,
+ resnet_groups: int,
+ output_scale_factor: float = 1.0,
+ transformer_layers_per_block: int = 1,
+ num_attention_heads: Optional[int] = None,
+ cross_attention_dim: Optional[int] = None,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ mid_block_only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ attention_type: str = "default",
+ resnet_skip_time_act: bool = False,
+ cross_attention_norm: Optional[str] = None,
+ attention_head_dim: Optional[int] = 1,
+ dropout: float = 0.0,
+):
+ if mid_block_type == "UNetMidBlock2DCrossAttn":
+ return UNetMidBlock2DCrossAttn(
+ transformer_layers_per_block=transformer_layers_per_block,
+ in_channels=in_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads,
+ resnet_groups=resnet_groups,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ attention_type=attention_type,
+ )
+ elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn":
+ return UNetMidBlock2DSimpleCrossAttn(
+ in_channels=in_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ cross_attention_dim=cross_attention_dim,
+ attention_head_dim=attention_head_dim,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ skip_time_act=resnet_skip_time_act,
+ only_cross_attention=mid_block_only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ )
+ elif mid_block_type == "UNetMidBlock2D":
+ return UNetMidBlock2D(
+ in_channels=in_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ num_layers=0,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ add_attention=False,
+ )
+ elif mid_block_type == "MidBlock2D":
+ return MidBlock2D(
+ in_channels=in_channels,
+ temb_channels=temb_channels,
+ dropout=dropout,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ resnet_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ )
+ elif mid_block_type is None:
+ return None
+ else:
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
+
+
+def get_up_block(
+ up_block_type: str,
+ num_layers: int,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ add_upsample: bool,
+ resnet_eps: float,
+ resnet_act_fn: str,
+ resolution_idx: Optional[int] = None,
+ transformer_layers_per_block: int = 1,
+ num_attention_heads: Optional[int] = None,
+ resnet_groups: Optional[int] = None,
+ cross_attention_dim: Optional[int] = None,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ attention_type: str = "default",
+ resnet_skip_time_act: bool = False,
+ resnet_out_scale_factor: float = 1.0,
+ cross_attention_norm: Optional[str] = None,
+ attention_head_dim: Optional[int] = None,
+ upsample_type: Optional[str] = None,
+ dropout: float = 0.0,
+) -> nn.Module:
+ # If attn head dim is not defined, we default it to the number of heads
+ if attention_head_dim is None:
+        logger.warning(
+ f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}."
+ )
+ attention_head_dim = num_attention_heads
+
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
+ if up_block_type == "UpBlock2D":
+ return UpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
+ dropout=dropout,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif up_block_type == "ResnetUpsampleBlock2D":
+ return ResnetUpsampleBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
+ dropout=dropout,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ skip_time_act=resnet_skip_time_act,
+ output_scale_factor=resnet_out_scale_factor,
+ )
+ elif up_block_type == "CrossAttnUpBlock2D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D")
+ return CrossAttnUpBlock2D(
+ num_layers=num_layers,
+ transformer_layers_per_block=transformer_layers_per_block,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
+ dropout=dropout,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_type=attention_type,
+ )
+ elif up_block_type == "SimpleCrossAttnUpBlock2D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D")
+ return SimpleCrossAttnUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
+ dropout=dropout,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ cross_attention_dim=cross_attention_dim,
+ attention_head_dim=attention_head_dim,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ skip_time_act=resnet_skip_time_act,
+ output_scale_factor=resnet_out_scale_factor,
+ only_cross_attention=only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ )
+ elif up_block_type == "AttnUpBlock2D":
+ if add_upsample is False:
+ upsample_type = None
+ else:
+ upsample_type = upsample_type or "conv" # default to 'conv'
+
+ return AttnUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
+ dropout=dropout,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ attention_head_dim=attention_head_dim,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ upsample_type=upsample_type,
+ )
+ elif up_block_type == "SkipUpBlock2D":
+ return SkipUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
+ dropout=dropout,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif up_block_type == "AttnSkipUpBlock2D":
+ return AttnSkipUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
+ dropout=dropout,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ attention_head_dim=attention_head_dim,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif up_block_type == "UpDecoderBlock2D":
+ return UpDecoderBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ resolution_idx=resolution_idx,
+ dropout=dropout,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ temb_channels=temb_channels,
+ )
+ elif up_block_type == "AttnUpDecoderBlock2D":
+ return AttnUpDecoderBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ resolution_idx=resolution_idx,
+ dropout=dropout,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ attention_head_dim=attention_head_dim,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ temb_channels=temb_channels,
+ )
+ elif up_block_type == "KUpBlock2D":
+ return KUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
+ dropout=dropout,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ )
+ elif up_block_type == "KCrossAttnUpBlock2D":
+ return KCrossAttnUpBlock2D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
+ dropout=dropout,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ cross_attention_dim=cross_attention_dim,
+ attention_head_dim=attention_head_dim,
+ )
+
+ raise ValueError(f"{up_block_type} does not exist.")
+
+
+class AutoencoderTinyBlock(nn.Module):
+ """
+ Tiny Autoencoder block used in [`AutoencoderTiny`]. It is a mini residual module consisting of plain conv + ReLU
+ blocks.
+
+ Args:
+ in_channels (`int`): The number of input channels.
+ out_channels (`int`): The number of output channels.
+ act_fn (`str`):
+            The activation function to use. Supported values are `"swish"`, `"mish"`, `"gelu"`, and `"relu"`.
+
+ Returns:
+ `torch.FloatTensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to
+ `out_channels`.
+ """
+
+ def __init__(self, in_channels: int, out_channels: int, act_fn: str):
+ super().__init__()
+ act_fn = get_activation(act_fn)
+ self.conv = nn.Sequential(
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
+ act_fn,
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+ act_fn,
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
+ )
+ self.skip = (
+ nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
+ if in_channels != out_channels
+ else nn.Identity()
+ )
+ self.fuse = nn.ReLU()
+
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
+ return self.fuse(self.conv(x) + self.skip(x))
+
+
+class UNetMidBlock2D(nn.Module):
+ """
+ A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks.
+
+ Args:
+ in_channels (`int`): The number of input channels.
+ temb_channels (`int`): The number of temporal embedding channels.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout rate.
+ num_layers (`int`, *optional*, defaults to 1): The number of residual blocks.
+        resnet_eps (`float`, *optional*, defaults to 1e-6): The epsilon value for the resnet blocks.
+ resnet_time_scale_shift (`str`, *optional*, defaults to `default`):
+ The type of normalization to apply to the time embeddings. This can help to improve the performance of the
+ model on tasks with long-range temporal dependencies.
+ resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks.
+ resnet_groups (`int`, *optional*, defaults to 32):
+ The number of groups to use in the group normalization layers of the resnet blocks.
+ attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks.
+ resnet_pre_norm (`bool`, *optional*, defaults to `True`):
+ Whether to use pre-normalization for the resnet blocks.
+ add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks.
+ attention_head_dim (`int`, *optional*, defaults to 1):
+ Dimension of a single attention head. The number of attention heads is determined based on this value and
+ the number of input channels.
+ output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
+
+ Returns:
+ `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
+ in_channels, height, width)`.
+
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default", # default, spatial
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ attn_groups: Optional[int] = None,
+ resnet_pre_norm: bool = True,
+ add_attention: bool = True,
+ attention_head_dim: int = 1,
+ output_scale_factor: float = 1.0,
+ ):
+ super().__init__()
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+ self.add_attention = add_attention
+
+ if attn_groups is None:
+ attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None
+
+ # there is always at least one resnet
+ if resnet_time_scale_shift == "spatial":
+ resnets = [
+ ResnetBlockCondNorm2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm="spatial",
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ )
+ ]
+ else:
+ resnets = [
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+ attentions = []
+
+ if attention_head_dim is None:
+            logger.warning(
+                f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
+ )
+ attention_head_dim = in_channels
+
+ for _ in range(num_layers):
+ if self.add_attention:
+ attentions.append(
+ Attention(
+ in_channels,
+ heads=in_channels // attention_head_dim,
+ dim_head=attention_head_dim,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ norm_num_groups=attn_groups,
+ spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
+ residual_connection=True,
+ bias=True,
+ upcast_softmax=True,
+ _from_deprecated_attn_block=True,
+ )
+ )
+ else:
+ attentions.append(None)
+
+ if resnet_time_scale_shift == "spatial":
+ resnets.append(
+ ResnetBlockCondNorm2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm="spatial",
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+ else:
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ if attn is not None:
+ hidden_states = attn(hidden_states, temb=temb)
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
+
+
+class UNetMidBlock2DCrossAttn(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ num_attention_heads: int = 1,
+ output_scale_factor: float = 1.0,
+ cross_attention_dim: int = 1280,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ upcast_attention: bool = False,
+ attention_type: str = "default",
+ ):
+ super().__init__()
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ # support for variable transformer layers per block
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+ attentions = []
+
+ for i in range(num_layers):
+ if not dual_cross_attention:
+ attentions.append(
+ Transformer2DModel(
+ num_attention_heads,
+ in_channels // num_attention_heads,
+ in_channels=in_channels,
+ num_layers=transformer_layers_per_block[i],
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ attention_type=attention_type,
+ )
+ )
+ else:
+ attentions.append(
+ DualTransformer2DModel(
+ num_attention_heads,
+ in_channels // num_attention_heads,
+ in_channels=in_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
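+ # the optional "scale" entry in `cross_attention_kwargs` is the LoRA scale forwarded to the resnets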
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+ hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ else:
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+
+ return hidden_states
+
+
+class UNetMidBlock2DSimpleCrossAttn(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attention_head_dim: int = 1,
+ output_scale_factor: float = 1.0,
+ cross_attention_dim: int = 1280,
+ skip_time_act: bool = False,
+ only_cross_attention: bool = False,
+ cross_attention_norm: Optional[str] = None,
+ ):
+ super().__init__()
+
+ self.has_cross_attention = True
+
+ self.attention_head_dim = attention_head_dim
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ self.num_heads = in_channels // self.attention_head_dim
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ )
+ ]
+ attentions = []
+
+ for _ in range(num_layers):
+ processor = (
+ AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+ )
+
+ attentions.append(
+ Attention(
+ query_dim=in_channels,
+ cross_attention_dim=in_channels,
+ heads=self.num_heads,
+ dim_head=self.attention_head_dim,
+ added_kv_proj_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ bias=True,
+ upcast_softmax=True,
+ only_cross_attention=only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ processor=processor,
+ )
+ )
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+ lora_scale = cross_attention_kwargs.get("scale", 1.0)
+
+ if attention_mask is None:
+ # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
+ mask = None if encoder_hidden_states is None else encoder_attention_mask
+ else:
+ # when attention_mask is defined: we don't even check for encoder_attention_mask.
+ # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
+ # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
+ # then we can simplify this whole if/else block to:
+ # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
+ mask = attention_mask
+
+ hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ # attn
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=mask,
+ **cross_attention_kwargs,
+ )
+
+ # resnet
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+
+ return hidden_states
+
+
+class MidBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor: float = 1.0,
+ use_linear_projection: bool = False,
+ ):
+ super().__init__()
+
+ self.has_cross_attention = False
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+
+ for i in range(num_layers):
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
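+ # this block has no cross-attention, so no `cross_attention_kwargs` are received; keep the LoRA scale at its default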
+ lora_scale = 1.0
+ hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
+ for resnet in self.resnets[1:]:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+
+ return hidden_states
+
+
+class AttnDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attention_head_dim: int = 1,
+ output_scale_factor: float = 1.0,
+ downsample_padding: int = 1,
+ downsample_type: str = "conv",
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+ self.downsample_type = downsample_type
+
+ if attention_head_dim is None:
+ logger.warning(
+ f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
+ )
+ attention_head_dim = out_channels
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ Attention(
+ out_channels,
+ heads=out_channels // attention_head_dim,
+ dim_head=attention_head_dim,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ norm_num_groups=resnet_groups,
+ residual_connection=True,
+ bias=True,
+ upcast_softmax=True,
+ _from_deprecated_attn_block=True,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if downsample_type == "conv":
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ elif downsample_type == "resnet":
+ self.downsamplers = nn.ModuleList(
+ [
+ ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ down=True,
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ upsample_size: Optional[int] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+
+ lora_scale = cross_attention_kwargs.get("scale", 1.0)
+
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
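+ # forward the LoRA scale to the attention processor via `cross_attention_kwargs`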
+ cross_attention_kwargs.update({"scale": lora_scale})
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+ hidden_states = attn(hidden_states, **cross_attention_kwargs)
+ output_states = output_states + (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ if self.downsample_type == "resnet":
+ hidden_states = downsampler(hidden_states, temb=temb, scale=lora_scale)
+ else:
+ hidden_states = downsampler(hidden_states, scale=lora_scale)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class CrossAttnDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ num_attention_heads: int = 1,
+ cross_attention_dim: int = 1280,
+ output_scale_factor: float = 1.0,
+ downsample_padding: int = 1,
+ add_downsample: bool = True,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ attention_type: str = "default",
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ if not dual_cross_attention:
+ attentions.append(
+ Transformer2DModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=transformer_layers_per_block[i],
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ attention_type=attention_type,
+ )
+ )
+ else:
+ attentions.append(
+ DualTransformer2DModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ additional_residuals: Optional[torch.FloatTensor] = None,
+ down_block_add_samples: Optional[torch.FloatTensor] = None,
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ output_states = ()
+
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+
+ blocks = list(zip(self.resnets, self.attentions))
+
+ for i, (resnet, attn) in enumerate(blocks):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
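+ # only the resnet is wrapped in torch.utils.checkpoint here; the transformer block below runs normally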
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ else:
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
+ if i == len(blocks) - 1 and additional_residuals is not None:
+ hidden_states = hidden_states + additional_residuals
+
+ if down_block_add_samples is not None:
+ hidden_states = hidden_states + down_block_add_samples.pop(0)
+
+ output_states = output_states + (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states, scale=lora_scale)
+
+ if down_block_add_samples is not None:
+ hidden_states = hidden_states + down_block_add_samples.pop(0) # todo: add before or after
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class DownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor: float = 1.0,
+ add_downsample: bool = True,
+ downsample_padding: int = 1,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ scale: float = 1.0,
+ down_block_add_samples: Optional[torch.FloatTensor] = None,
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ output_states = ()
+
+ for resnet in self.resnets:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+ )
+ else:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb, scale=scale)
+
+ if down_block_add_samples is not None:
+ hidden_states = hidden_states + down_block_add_samples.pop(0)
+
+ output_states = output_states + (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states, scale=scale)
+
+ if down_block_add_samples is not None:
+ hidden_states = hidden_states + down_block_add_samples.pop(0) # todo: add before or after
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class DownEncoderBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor: float = 1.0,
+ add_downsample: bool = True,
+ downsample_padding: int = 1,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ if resnet_time_scale_shift == "spatial":
+ resnets.append(
+ ResnetBlockCondNorm2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm="spatial",
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+ else:
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb=None, scale=scale)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states, scale)
+
+ return hidden_states
+
+
+class AttnDownEncoderBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attention_head_dim: int = 1,
+ output_scale_factor: float = 1.0,
+ add_downsample: bool = True,
+ downsample_padding: int = 1,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ if attention_head_dim is None:
+ logger.warning(
+ f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
+ )
+ attention_head_dim = out_channels
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ if resnet_time_scale_shift == "spatial":
+ resnets.append(
+ ResnetBlockCondNorm2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm="spatial",
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+ else:
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ Attention(
+ out_channels,
+ heads=out_channels // attention_head_dim,
+ dim_head=attention_head_dim,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ norm_num_groups=resnet_groups,
+ residual_connection=True,
+ bias=True,
+ upcast_softmax=True,
+ _from_deprecated_attn_block=True,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb=None, scale=scale)
+ cross_attention_kwargs = {"scale": scale}
+ hidden_states = attn(hidden_states, **cross_attention_kwargs)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states, scale)
+
+ return hidden_states
+
+
+class AttnSkipDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ attention_head_dim: int = 1,
+ output_scale_factor: float = np.sqrt(2.0),
+ add_downsample: bool = True,
+ ):
+ super().__init__()
+ self.attentions = nn.ModuleList([])
+ self.resnets = nn.ModuleList([])
+
+ if attention_head_dim is None:
+ logger.warning(
+ f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
+ )
+ attention_head_dim = out_channels
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(in_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ self.attentions.append(
+ Attention(
+ out_channels,
+ heads=out_channels // attention_head_dim,
+ dim_head=attention_head_dim,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ norm_num_groups=32,
+ residual_connection=True,
+ bias=True,
+ upcast_softmax=True,
+ _from_deprecated_attn_block=True,
+ )
+ )
+
+ if add_downsample:
+ self.resnet_down = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_in_shortcut=True,
+ down=True,
+ kernel="fir",
+ )
+ self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])
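+ # `skip_sample` carries the 3-channel image skip branch; project it to the feature width before adding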
+ self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
+ else:
+ self.resnet_down = None
+ self.downsamplers = None
+ self.skip_conv = None
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ skip_sample: Optional[torch.FloatTensor] = None,
+ scale: float = 1.0,
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]:
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb, scale=scale)
+ cross_attention_kwargs = {"scale": scale}
+ hidden_states = attn(hidden_states, **cross_attention_kwargs)
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ hidden_states = self.resnet_down(hidden_states, temb, scale=scale)
+ for downsampler in self.downsamplers:
+ skip_sample = downsampler(skip_sample)
+
+ hidden_states = self.skip_conv(skip_sample) + hidden_states
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states, skip_sample
+
+
+class SkipDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ output_scale_factor: float = np.sqrt(2.0),
+ add_downsample: bool = True,
+ downsample_padding: int = 1,
+ ):
+ super().__init__()
+ self.resnets = nn.ModuleList([])
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(in_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ if add_downsample:
+ self.resnet_down = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_in_shortcut=True,
+ down=True,
+ kernel="fir",
+ )
+ self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)])
+ self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1))
+ else:
+ self.resnet_down = None
+ self.downsamplers = None
+ self.skip_conv = None
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ skip_sample: Optional[torch.FloatTensor] = None,
+ scale: float = 1.0,
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]:
+ output_states = ()
+
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb, scale)
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ hidden_states = self.resnet_down(hidden_states, temb, scale)
+ for downsampler in self.downsamplers:
+ skip_sample = downsampler(skip_sample)
+
+ hidden_states = self.skip_conv(skip_sample) + hidden_states
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states, skip_sample
+
+
+class ResnetDownsampleBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor: float = 1.0,
+ add_downsample: bool = True,
+ skip_time_act: bool = False,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ down=True,
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ output_states = ()
+
+ for resnet in self.resnets:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+ )
+ else:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb, scale)
+
+ output_states = output_states + (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states, temb, scale)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class SimpleCrossAttnDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attention_head_dim: int = 1,
+ cross_attention_dim: int = 1280,
+ output_scale_factor: float = 1.0,
+ add_downsample: bool = True,
+ skip_time_act: bool = False,
+ only_cross_attention: bool = False,
+ cross_attention_norm: Optional[str] = None,
+ ):
+ super().__init__()
+
+ self.has_cross_attention = True
+
+ resnets = []
+ attentions = []
+
+ self.attention_head_dim = attention_head_dim
+ self.num_heads = out_channels // self.attention_head_dim
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ )
+ )
+
+ processor = (
+ AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+ )
+
+ attentions.append(
+ Attention(
+ query_dim=out_channels,
+ cross_attention_dim=out_channels,
+ heads=self.num_heads,
+ dim_head=attention_head_dim,
+ added_kv_proj_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ bias=True,
+ upcast_softmax=True,
+ only_cross_attention=only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ processor=processor,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ down=True,
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ output_states = ()
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+
+ lora_scale = cross_attention_kwargs.get("scale", 1.0)
+
+ if attention_mask is None:
+ # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
+ mask = None if encoder_hidden_states is None else encoder_attention_mask
+ else:
+ # when attention_mask is defined: we don't even check for encoder_attention_mask.
+ # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
+ # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
+ # then we can simplify this whole if/else block to:
+ # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
+ mask = attention_mask
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=mask,
+ **cross_attention_kwargs,
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=mask,
+ **cross_attention_kwargs,
+ )
+
+ output_states = output_states + (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states, temb, scale=lora_scale)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class KDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 4,
+ resnet_eps: float = 1e-5,
+ resnet_act_fn: str = "gelu",
+ resnet_group_size: int = 32,
+ add_downsample: bool = False,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ groups = in_channels // resnet_group_size
+ groups_out = out_channels // resnet_group_size
+
+ resnets.append(
+ ResnetBlockCondNorm2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ dropout=dropout,
+ temb_channels=temb_channels,
+ groups=groups,
+ groups_out=groups_out,
+ eps=resnet_eps,
+ non_linearity=resnet_act_fn,
+ time_embedding_norm="ada_group",
+ conv_shortcut_bias=False,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ # YiYi's comments- might be able to use FirDownsample2D, look into details later
+ self.downsamplers = nn.ModuleList([KDownsample2D()])
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ output_states = ()
+
+ for resnet in self.resnets:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+ )
+ else:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb, scale)
+
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ return hidden_states, output_states
+
+
+class KCrossAttnDownBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ cross_attention_dim: int,
+ dropout: float = 0.0,
+ num_layers: int = 4,
+ resnet_group_size: int = 32,
+ add_downsample: bool = True,
+ attention_head_dim: int = 64,
+ add_self_attention: bool = False,
+ resnet_eps: float = 1e-5,
+ resnet_act_fn: str = "gelu",
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.has_cross_attention = True
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ groups = in_channels // resnet_group_size
+ groups_out = out_channels // resnet_group_size
+
+ resnets.append(
+ ResnetBlockCondNorm2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ dropout=dropout,
+ temb_channels=temb_channels,
+ groups=groups,
+ groups_out=groups_out,
+ eps=resnet_eps,
+ non_linearity=resnet_act_fn,
+ time_embedding_norm="ada_group",
+ conv_shortcut_bias=False,
+ )
+ )
+ attentions.append(
+ KAttentionBlock(
+ out_channels,
+ out_channels // attention_head_dim,
+ attention_head_dim,
+ cross_attention_dim=cross_attention_dim,
+ temb_channels=temb_channels,
+ attention_bias=True,
+ add_self_attention=add_self_attention,
+ cross_attention_norm="layer_norm",
+ group_size=resnet_group_size,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.attentions = nn.ModuleList(attentions)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList([KDownsample2D()])
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ output_states = ()
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ emb=temb,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ emb=temb,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+
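+ # when no downsampler follows, store a `None` placeholder instead of a residual for the skip path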
+ if self.downsamplers is None:
+ output_states += (None,)
+ else:
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ return hidden_states, output_states
+
+
+class AttnUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ resolution_idx: Optional[int] = None,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attention_head_dim: int = 1,
+ output_scale_factor: float = 1.0,
+ upsample_type: str = "conv",
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.upsample_type = upsample_type
+
+ if attention_head_dim is None:
+ logger.warning(
+ f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
+ )
+ attention_head_dim = out_channels
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ attentions.append(
+ Attention(
+ out_channels,
+ heads=out_channels // attention_head_dim,
+ dim_head=attention_head_dim,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ norm_num_groups=resnet_groups,
+ residual_connection=True,
+ bias=True,
+ upcast_softmax=True,
+ _from_deprecated_attn_block=True,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if upsample_type == "conv":
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ elif upsample_type == "resnet":
+ self.upsamplers = nn.ModuleList(
+ [
+ ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ up=True,
+ )
+ ]
+ )
+ else:
+ self.upsamplers = None
+
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ upsample_size: Optional[int] = None,
+ scale: float = 1.0,
+ ) -> torch.FloatTensor:
+ for resnet, attn in zip(self.resnets, self.attentions):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb, scale=scale)
+ cross_attention_kwargs = {"scale": scale}
+ hidden_states = attn(hidden_states, **cross_attention_kwargs)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ if self.upsample_type == "resnet":
+ hidden_states = upsampler(hidden_states, temb=temb, scale=scale)
+ else:
+ hidden_states = upsampler(hidden_states, scale=scale)
+
+ return hidden_states
+
+
+class CrossAttnUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ resolution_idx: Optional[int] = None,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ num_attention_heads: int = 1,
+ cross_attention_dim: int = 1280,
+ output_scale_factor: float = 1.0,
+ add_upsample: bool = True,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ attention_type: str = "default",
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ if not dual_cross_attention:
+ attentions.append(
+ Transformer2DModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=transformer_layers_per_block[i],
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ attention_type=attention_type,
+ )
+ )
+ else:
+ attentions.append(
+ DualTransformer2DModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ upsample_size: Optional[int] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ return_res_samples: Optional[bool] = False,
+ up_block_add_samples: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
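+ # FreeU is considered enabled only when all four scaling factors (`s1`, `s2`, `b1`, `b2`) have been set on this block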
+ is_freeu_enabled = (
+ getattr(self, "s1", None)
+ and getattr(self, "s2", None)
+ and getattr(self, "b1", None)
+ and getattr(self, "b2", None)
+ )
+ if return_res_samples:
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+ # FreeU: Only operate on the first two stages
+ if is_freeu_enabled:
+ hidden_states, res_hidden_states = apply_freeu(
+ self.resolution_idx,
+ hidden_states,
+ res_hidden_states,
+ s1=self.s1,
+ s2=self.s2,
+ b1=self.b1,
+ b2=self.b2,
+ )
+
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ else:
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ if return_res_samples:
+ output_states = output_states + (hidden_states,)
+ if up_block_add_samples is not None:
+ hidden_states = hidden_states + up_block_add_samples.pop(0)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
+ if return_res_samples:
+ output_states = output_states + (hidden_states,)
+ if up_block_add_samples is not None:
+ hidden_states = hidden_states + up_block_add_samples.pop(0)
+
+ if return_res_samples:
+ return hidden_states, output_states
+ else:
+ return hidden_states
+
+
+class UpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ resolution_idx: Optional[int] = None,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor: float = 1.0,
+ add_upsample: bool = True,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ upsample_size: Optional[int] = None,
+ scale: float = 1.0,
+ return_res_samples: Optional[bool] = False,
+ up_block_add_samples: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ is_freeu_enabled = (
+ getattr(self, "s1", None)
+ and getattr(self, "s2", None)
+ and getattr(self, "b1", None)
+ and getattr(self, "b2", None)
+ )
+ if return_res_samples:
+ output_states = ()
+
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+ # FreeU: Only operate on the first two stages
+ if is_freeu_enabled:
+ hidden_states, res_hidden_states = apply_freeu(
+ self.resolution_idx,
+ hidden_states,
+ res_hidden_states,
+ s1=self.s1,
+ s2=self.s2,
+ b1=self.b1,
+ b2=self.b2,
+ )
+
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+ )
+ else:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb, scale=scale)
+
+ if return_res_samples:
+ output_states = output_states + (hidden_states,)
+ if up_block_add_samples is not None:
+ hidden_states = hidden_states + up_block_add_samples.pop(0) # todo: add before or after
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
+
+ if return_res_samples:
+ output_states = output_states + (hidden_states,)
+ if up_block_add_samples is not None:
+ hidden_states = hidden_states + up_block_add_samples.pop(0) # todo: add before or after
+
+ if return_res_samples:
+ return hidden_states, output_states
+ else:
+ return hidden_states
+
+
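+# Usage note (illustrative sketch, not upstream diffusers API): in this fork,
+# `UpBlock2D.forward` can also return its per-layer outputs and accept extra
+# residuals. `up_block_add_samples` is consumed as a list, popped front-to-back,
+# with one entry per resnet and, if the block upsamples, one final entry.
+# Assuming correctly shaped tensors:
+#
+#     block = UpBlock2D(in_channels=320, prev_output_channel=640,
+#                       out_channels=320, temb_channels=1280, num_layers=3)
+#     hidden, res_outs = block(
+#         hidden_states, res_hidden_states_tuple, temb,
+#         return_res_samples=True,
+#         up_block_add_samples=list(extra_residuals),
+#     )
+
+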
+class UpDecoderBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ resolution_idx: Optional[int] = None,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default", # default, spatial
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor: float = 1.0,
+ add_upsample: bool = True,
+ temb_channels: Optional[int] = None,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ input_channels = in_channels if i == 0 else out_channels
+
+ if resnet_time_scale_shift == "spatial":
+ resnets.append(
+ ResnetBlockCondNorm2D(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm="spatial",
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+ else:
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
+ ) -> torch.FloatTensor:
+ for resnet in self.resnets:
+ hidden_states = resnet(hidden_states, temb=temb, scale=scale)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
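+# Note: UpDecoderBlock2D and AttnUpDecoderBlock2D take no skip connections and only an
+# optional `temb`, which is why diffusers uses them in decoder-only stacks such as the
+# VAE decoder (an observation about usage, not something enforced here).
+
+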
+class AttnUpDecoderBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ resolution_idx: Optional[int] = None,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attention_head_dim: int = 1,
+ output_scale_factor: float = 1.0,
+ add_upsample: bool = True,
+ temb_channels: Optional[int] = None,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ if attention_head_dim is None:
+            logger.warning(
+                f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
+ )
+ attention_head_dim = out_channels
+
+ for i in range(num_layers):
+ input_channels = in_channels if i == 0 else out_channels
+
+ if resnet_time_scale_shift == "spatial":
+ resnets.append(
+ ResnetBlockCondNorm2D(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm="spatial",
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ )
+ )
+ else:
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ attentions.append(
+ Attention(
+ out_channels,
+ heads=out_channels // attention_head_dim,
+ dim_head=attention_head_dim,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ norm_num_groups=resnet_groups if resnet_time_scale_shift != "spatial" else None,
+ spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None,
+ residual_connection=True,
+ bias=True,
+ upcast_softmax=True,
+ _from_deprecated_attn_block=True,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
+ ) -> torch.FloatTensor:
+ for resnet, attn in zip(self.resnets, self.attentions):
+ hidden_states = resnet(hidden_states, temb=temb, scale=scale)
+ cross_attention_kwargs = {"scale": scale}
+ hidden_states = attn(hidden_states, temb=temb, **cross_attention_kwargs)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, scale=scale)
+
+ return hidden_states
+
+
+class AttnSkipUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ resolution_idx: Optional[int] = None,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ attention_head_dim: int = 1,
+ output_scale_factor: float = np.sqrt(2.0),
+ add_upsample: bool = True,
+ ):
+ super().__init__()
+ self.attentions = nn.ModuleList([])
+ self.resnets = nn.ModuleList([])
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+                    groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ if attention_head_dim is None:
+                logger.warning(
+                    f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
+ )
+ attention_head_dim = out_channels
+
+ self.attentions.append(
+ Attention(
+ out_channels,
+ heads=out_channels // attention_head_dim,
+ dim_head=attention_head_dim,
+ rescale_output_factor=output_scale_factor,
+ eps=resnet_eps,
+ norm_num_groups=32,
+ residual_connection=True,
+ bias=True,
+ upcast_softmax=True,
+ _from_deprecated_attn_block=True,
+ )
+ )
+
+ self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
+ if add_upsample:
+ self.resnet_up = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_in_shortcut=True,
+ up=True,
+ kernel="fir",
+ )
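+            # The skip path below projects features to a 3-channel residual that is
+            # accumulated into the returned `skip_sample`.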
+ self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+ self.skip_norm = torch.nn.GroupNorm(
+ num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
+ )
+ self.act = nn.SiLU()
+ else:
+ self.resnet_up = None
+ self.skip_conv = None
+ self.skip_norm = None
+ self.act = None
+
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ skip_sample=None,
+ scale: float = 1.0,
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb, scale=scale)
+
+ cross_attention_kwargs = {"scale": scale}
+ hidden_states = self.attentions[0](hidden_states, **cross_attention_kwargs)
+
+ if skip_sample is not None:
+ skip_sample = self.upsampler(skip_sample)
+ else:
+ skip_sample = 0
+
+ if self.resnet_up is not None:
+ skip_sample_states = self.skip_norm(hidden_states)
+ skip_sample_states = self.act(skip_sample_states)
+ skip_sample_states = self.skip_conv(skip_sample_states)
+
+ skip_sample = skip_sample + skip_sample_states
+
+ hidden_states = self.resnet_up(hidden_states, temb, scale=scale)
+
+ return hidden_states, skip_sample
+
+
+class SkipUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ resolution_idx: Optional[int] = None,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_pre_norm: bool = True,
+ output_scale_factor: float = np.sqrt(2.0),
+ add_upsample: bool = True,
+ upsample_padding: int = 1,
+ ):
+ super().__init__()
+ self.resnets = nn.ModuleList([])
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ self.resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
+ if add_upsample:
+ self.resnet_up = ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=min(out_channels // 4, 32),
+ groups_out=min(out_channels // 4, 32),
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ use_in_shortcut=True,
+ up=True,
+ kernel="fir",
+ )
+ self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+ self.skip_norm = torch.nn.GroupNorm(
+ num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
+ )
+ self.act = nn.SiLU()
+ else:
+ self.resnet_up = None
+ self.skip_conv = None
+ self.skip_norm = None
+ self.act = None
+
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ skip_sample=None,
+ scale: float = 1.0,
+ ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb, scale=scale)
+
+ if skip_sample is not None:
+ skip_sample = self.upsampler(skip_sample)
+ else:
+ skip_sample = 0
+
+ if self.resnet_up is not None:
+ skip_sample_states = self.skip_norm(hidden_states)
+ skip_sample_states = self.act(skip_sample_states)
+ skip_sample_states = self.skip_conv(skip_sample_states)
+
+ skip_sample = skip_sample + skip_sample_states
+
+ hidden_states = self.resnet_up(hidden_states, temb, scale=scale)
+
+ return hidden_states, skip_sample
+
+
+class ResnetUpsampleBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ resolution_idx: Optional[int] = None,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor: float = 1.0,
+ add_upsample: bool = True,
+ skip_time_act: bool = False,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList(
+ [
+ ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ up=True,
+ )
+ ]
+ )
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ upsample_size: Optional[int] = None,
+ scale: float = 1.0,
+ ) -> torch.FloatTensor:
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+ )
+ else:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb, scale=scale)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, temb, scale=scale)
+
+ return hidden_states
+
+
+class SimpleCrossAttnUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ resolution_idx: Optional[int] = None,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attention_head_dim: int = 1,
+ cross_attention_dim: int = 1280,
+ output_scale_factor: float = 1.0,
+ add_upsample: bool = True,
+ skip_time_act: bool = False,
+ only_cross_attention: bool = False,
+ cross_attention_norm: Optional[str] = None,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.has_cross_attention = True
+ self.attention_head_dim = attention_head_dim
+
+ self.num_heads = out_channels // self.attention_head_dim
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ )
+ )
+
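+            # Prefer the fused scaled-dot-product-attention processor on PyTorch 2.x
+            # (where `F.scaled_dot_product_attention` exists); otherwise fall back to the
+            # eager added-KV processor.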
+ processor = (
+ AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor()
+ )
+
+ attentions.append(
+ Attention(
+ query_dim=out_channels,
+ cross_attention_dim=out_channels,
+ heads=self.num_heads,
+ dim_head=self.attention_head_dim,
+ added_kv_proj_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ bias=True,
+ upcast_softmax=True,
+ only_cross_attention=only_cross_attention,
+ cross_attention_norm=cross_attention_norm,
+ processor=processor,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList(
+ [
+ ResnetBlock2D(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ skip_time_act=skip_time_act,
+ up=True,
+ )
+ ]
+ )
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ upsample_size: Optional[int] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+
+ lora_scale = cross_attention_kwargs.get("scale", 1.0)
+ if attention_mask is None:
+ # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
+ mask = None if encoder_hidden_states is None else encoder_attention_mask
+ else:
+ # when attention_mask is defined: we don't even check for encoder_attention_mask.
+ # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks.
+ # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask.
+ # then we can simplify this whole if/else block to:
+ # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
+ mask = attention_mask
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ # resnet
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=mask,
+ **cross_attention_kwargs,
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=mask,
+ **cross_attention_kwargs,
+ )
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, temb, scale=lora_scale)
+
+ return hidden_states
+
+
+class KUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ resolution_idx: int,
+ dropout: float = 0.0,
+ num_layers: int = 5,
+ resnet_eps: float = 1e-5,
+ resnet_act_fn: str = "gelu",
+ resnet_group_size: Optional[int] = 32,
+ add_upsample: bool = True,
+ ):
+ super().__init__()
+ resnets = []
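+        # Channel bookkeeping for the k-unet: the skip connection is concatenated along
+        # dim=1, so the first resnet sees 2 * out_channels; the last resnet projects back
+        # to `in_channels`, and one fewer resnet than `num_layers` is stacked.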
+ k_in_channels = 2 * out_channels
+ k_out_channels = in_channels
+ num_layers = num_layers - 1
+
+ for i in range(num_layers):
+ in_channels = k_in_channels if i == 0 else out_channels
+ groups = in_channels // resnet_group_size
+ groups_out = out_channels // resnet_group_size
+
+ resnets.append(
+ ResnetBlockCondNorm2D(
+ in_channels=in_channels,
+ out_channels=k_out_channels if (i == num_layers - 1) else out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=groups,
+ groups_out=groups_out,
+ dropout=dropout,
+ non_linearity=resnet_act_fn,
+ time_embedding_norm="ada_group",
+ conv_shortcut_bias=False,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([KUpsample2D()])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ upsample_size: Optional[int] = None,
+ scale: float = 1.0,
+ ) -> torch.FloatTensor:
+ res_hidden_states_tuple = res_hidden_states_tuple[-1]
+ if res_hidden_states_tuple is not None:
+ hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
+
+ for resnet in self.resnets:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
+ )
+ else:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb, scale=scale)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class KCrossAttnUpBlock2D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ resolution_idx: int,
+ dropout: float = 0.0,
+ num_layers: int = 4,
+ resnet_eps: float = 1e-5,
+ resnet_act_fn: str = "gelu",
+ resnet_group_size: int = 32,
+ attention_head_dim: int = 1, # attention dim_head
+ cross_attention_dim: int = 768,
+ add_upsample: bool = True,
+ upcast_attention: bool = False,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
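+        # Block-type flags for the k-unet: a block whose in/out/temb channel counts all
+        # match is treated as the first up block and is the only one that adds
+        # self-attention; blocks that change channel count ("middle" blocks) make their
+        # final resnet project to `k_out_channels`.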
+ is_first_block = in_channels == out_channels == temb_channels
+ is_middle_block = in_channels != out_channels
+ add_self_attention = True if is_first_block else False
+
+ self.has_cross_attention = True
+ self.attention_head_dim = attention_head_dim
+
+        # in_channels and out_channels for the k-unet block
+ k_in_channels = out_channels if is_first_block else 2 * out_channels
+ k_out_channels = in_channels
+
+ num_layers = num_layers - 1
+
+ for i in range(num_layers):
+ in_channels = k_in_channels if i == 0 else out_channels
+ groups = in_channels // resnet_group_size
+ groups_out = out_channels // resnet_group_size
+
+ if is_middle_block and (i == num_layers - 1):
+ conv_2d_out_channels = k_out_channels
+ else:
+ conv_2d_out_channels = None
+
+ resnets.append(
+ ResnetBlockCondNorm2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ conv_2d_out_channels=conv_2d_out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=groups,
+ groups_out=groups_out,
+ dropout=dropout,
+ non_linearity=resnet_act_fn,
+ time_embedding_norm="ada_group",
+ conv_shortcut_bias=False,
+ )
+ )
+ attentions.append(
+ KAttentionBlock(
+ k_out_channels if (i == num_layers - 1) else out_channels,
+ k_out_channels // attention_head_dim
+ if (i == num_layers - 1)
+ else out_channels // attention_head_dim,
+ attention_head_dim,
+ cross_attention_dim=cross_attention_dim,
+ temb_channels=temb_channels,
+ attention_bias=True,
+ add_self_attention=add_self_attention,
+ cross_attention_norm="layer_norm",
+ upcast_attention=upcast_attention,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.attentions = nn.ModuleList(attentions)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([KUpsample2D()])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ upsample_size: Optional[int] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ res_hidden_states_tuple = res_hidden_states_tuple[-1]
+ if res_hidden_states_tuple is not None:
+ hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
+
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+ for resnet, attn in zip(self.resnets, self.attentions):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ emb=temb,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ emb=temb,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+# can potentially later be renamed to `No-feed-forward` attention
+class KAttentionBlock(nn.Module):
+ r"""
+ A basic Transformer block.
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+ attention_bias (`bool`, *optional*, defaults to `False`):
+ Configure if the attention layers should contain a bias parameter.
+ upcast_attention (`bool`, *optional*, defaults to `False`):
+ Set to `True` to upcast the attention computation to `float32`.
+ temb_channels (`int`, *optional*, defaults to 768):
+            The number of channels in the conditioning embedding (`emb`) consumed by the `AdaGroupNorm` layers.
+ add_self_attention (`bool`, *optional*, defaults to `False`):
+ Set to `True` to add self-attention to the block.
+ cross_attention_norm (`str`, *optional*, defaults to `None`):
+ The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`.
+ group_size (`int`, *optional*, defaults to 32):
+            The number of channels per group in the `AdaGroupNorm` layers (the number of groups is `dim // group_size`).
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout: float = 0.0,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ upcast_attention: bool = False,
+ temb_channels: int = 768, # for ada_group_norm
+ add_self_attention: bool = False,
+ cross_attention_norm: Optional[str] = None,
+ group_size: int = 32,
+ ):
+ super().__init__()
+ self.add_self_attention = add_self_attention
+
+ # 1. Self-Attn
+ if add_self_attention:
+ self.norm1 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size))
+ self.attn1 = Attention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=None,
+ cross_attention_norm=None,
+ )
+
+ # 2. Cross-Attn
+ self.norm2 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size))
+ self.attn2 = Attention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ cross_attention_norm=cross_attention_norm,
+ )
+
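+    # Helpers to flatten NCHW feature maps into (batch, height*width, channels) token
+    # sequences for attention and to restore them afterwards ("weight" here means width).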
+ def _to_3d(self, hidden_states: torch.FloatTensor, height: int, weight: int) -> torch.FloatTensor:
+ return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * weight, -1)
+
+ def _to_4d(self, hidden_states: torch.FloatTensor, height: int, weight: int) -> torch.FloatTensor:
+ return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, weight)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ # TODO: mark emb as non-optional (self.norm2 requires it).
+ # requires assessing impact of change to positional param interface.
+ emb: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
+
+ # 1. Self-Attention
+ if self.add_self_attention:
+ norm_hidden_states = self.norm1(hidden_states, emb)
+
+ height, weight = norm_hidden_states.shape[2:]
+ norm_hidden_states = self._to_3d(norm_hidden_states, height, weight)
+
+ attn_output = self.attn1(
+ norm_hidden_states,
+ encoder_hidden_states=None,
+ attention_mask=attention_mask,
+ **cross_attention_kwargs,
+ )
+ attn_output = self._to_4d(attn_output, height, weight)
+
+ hidden_states = attn_output + hidden_states
+
+ # 2. Cross-Attention/None
+ norm_hidden_states = self.norm2(hidden_states, emb)
+
+ height, weight = norm_hidden_states.shape[2:]
+ norm_hidden_states = self._to_3d(norm_hidden_states, height, weight)
+ attn_output = self.attn2(
+ norm_hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask if encoder_hidden_states is None else encoder_attention_mask,
+ **cross_attention_kwargs,
+ )
+ attn_output = self._to_4d(attn_output, height, weight)
+
+ hidden_states = attn_output + hidden_states
+
+ return hidden_states
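+
+
+# Shape note (illustrative): KAttentionBlock consumes and returns NCHW feature maps;
+# `emb` is the conditioning/time embedding required by the AdaGroupNorm layers, so
+# callers such as KCrossAttnUpBlock2D always pass it.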
diff --git a/libs/unet_2d_condition.py b/libs/unet_2d_condition.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8dffb7f9c945e8f3ebd77e2b91eab3b85e08dd0
--- /dev/null
+++ b/libs/unet_2d_condition.py
@@ -0,0 +1,1359 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
+from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
+from diffusers.models.activations import get_activation
+from diffusers.models.attention_processor import (
+ ADDED_KV_ATTENTION_PROCESSORS,
+ CROSS_ATTENTION_PROCESSORS,
+ Attention,
+ AttentionProcessor,
+ AttnAddedKVProcessor,
+ AttnProcessor,
+)
+from diffusers.models.embeddings import (
+ GaussianFourierProjection,
+ GLIGENTextBoundingboxProjection,
+ ImageHintTimeEmbedding,
+ ImageProjection,
+ ImageTimeEmbedding,
+ TextImageProjection,
+ TextImageTimeEmbedding,
+ TextTimeEmbedding,
+ TimestepEmbedding,
+ Timesteps,
+)
+from diffusers.models.modeling_utils import ModelMixin
+from .unet_2d_blocks import (
+ get_down_block,
+ get_mid_block,
+ get_up_block,
+)
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class UNet2DConditionOutput(BaseOutput):
+ """
+ The output of [`UNet2DConditionModel`].
+
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+ The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
+ """
+
+ sample: torch.FloatTensor = None
+
+
+class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
+ r"""
+ A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
+ shaped output.
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+ for all models (such as downloading or saving).
+
+ Parameters:
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
+ Height and width of input/output sample.
+ in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample.
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
+ center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample.
+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
+ Whether to flip the sin to cos in the time embedding.
+ freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
+ The tuple of downsample blocks to use.
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or
+ `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped.
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
+ The tuple of upsample blocks to use.
+        only_cross_attention (`bool` or `Tuple[bool]`, *optional*, defaults to `False`):
+ Whether to include self-attention in the basic transformer blocks, see
+ [`~models.attention.BasicTransformerBlock`].
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each block.
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
+ downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution.
+ mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
+            If `None`, normalization and activation layers are skipped in post-processing.
+ norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
+ The dimension of the cross attention features.
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+        reverse_transformer_layers_per_block (`Tuple[Tuple]`, *optional*, defaults to `None`):
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling
+ blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for
+ [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`],
+ [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`].
+ encoder_hid_dim (`int`, *optional*, defaults to None):
+ If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim`
+ dimension to `cross_attention_dim`.
+ encoder_hid_dim_type (`str`, *optional*, defaults to `None`):
+ If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text
+            embeddings of dimension `cross_attention_dim` according to `encoder_hid_dim_type`.
+ attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
+ num_attention_heads (`int`, *optional*):
+            The number of attention heads. If not defined, defaults to `attention_head_dim`.
+ resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config
+ for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`.
+ class_embed_type (`str`, *optional*, defaults to `None`):
+ The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`,
+ `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`.
+ addition_embed_type (`str`, *optional*, defaults to `None`):
+ Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or
+ "text". "text" will use the `TextTimeEmbedding` layer.
+        addition_time_embed_dim (`int`, *optional*, defaults to `None`):
+ Dimension for the timestep embeddings.
+ num_class_embeds (`int`, *optional*, defaults to `None`):
+ Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing
+ class conditioning with `class_embed_type` equal to `None`.
+ time_embedding_type (`str`, *optional*, defaults to `positional`):
+ The type of position embedding to use for timesteps. Choose from `positional` or `fourier`.
+ time_embedding_dim (`int`, *optional*, defaults to `None`):
+ An optional override for the dimension of the projected time embedding.
+ time_embedding_act_fn (`str`, *optional*, defaults to `None`):
+ Optional activation function to use only once on the time embeddings before they are passed to the rest of
+ the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`.
+ timestep_post_act (`str`, *optional*, defaults to `None`):
+ The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`.
+ time_cond_proj_dim (`int`, *optional*, defaults to `None`):
+ The dimension of `cond_proj` layer in the timestep embedding.
+        conv_in_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_in` layer.
+        conv_out_kernel (`int`, *optional*, defaults to `3`): The kernel size of the `conv_out` layer.
+        projection_class_embeddings_input_dim (`int`, *optional*):
+            The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when
+            `class_embed_type="projection"`.
+ class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time
+ embeddings with the class embeddings.
+ mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`):
+ Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If
+ `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the
+            `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Defaults to `False`
+            otherwise.
+ """
+
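+    # Usage sketch (illustrative, not authoritative): the model can be configured
+    # directly from keyword arguments, e.g.
+    #
+    #     unet = UNet2DConditionModel(sample_size=64, cross_attention_dim=768)
+    #     out = unet(sample, timestep, encoder_hidden_states).sample
+    #
+    # or re-created from a saved configuration via `UNet2DConditionModel.from_config(...)`.
+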
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 4,
+ out_channels: int = 4,
+ center_input_sample: bool = False,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "CrossAttnDownBlock2D",
+ "DownBlock2D",
+ ),
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn",
+ up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: Union[int, Tuple[int]] = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ dropout: float = 0.0,
+ act_fn: str = "silu",
+ norm_num_groups: Optional[int] = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: Union[int, Tuple[int]] = 1280,
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
+ reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None,
+ encoder_hid_dim: Optional[int] = None,
+ encoder_hid_dim_type: Optional[str] = None,
+ attention_head_dim: Union[int, Tuple[int]] = 8,
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ addition_embed_type: Optional[str] = None,
+ addition_time_embed_dim: Optional[int] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ resnet_skip_time_act: bool = False,
+ resnet_out_scale_factor: float = 1.0,
+ time_embedding_type: str = "positional",
+ time_embedding_dim: Optional[int] = None,
+ time_embedding_act_fn: Optional[str] = None,
+ timestep_post_act: Optional[str] = None,
+ time_cond_proj_dim: Optional[int] = None,
+ conv_in_kernel: int = 3,
+ conv_out_kernel: int = 3,
+ projection_class_embeddings_input_dim: Optional[int] = None,
+ attention_type: str = "default",
+ class_embeddings_concat: bool = False,
+ mid_block_only_cross_attention: Optional[bool] = None,
+ cross_attention_norm: Optional[str] = None,
+ addition_embed_type_num_heads: int = 64,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+
+ if num_attention_heads is not None:
+ raise ValueError(
+ "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19."
+ )
+
+ # If `num_attention_heads` is not defined (which is the case for most models)
+ # it will default to `attention_head_dim`. This looks weird upon first reading it and it is.
+ # The reason for this behavior is to correct for incorrectly named variables that were introduced
+ # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131
+ # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking
+ # which is why we correct for the naming here.
+ num_attention_heads = num_attention_heads or attention_head_dim
+
+ # Check inputs
+ self._check_config(
+ down_block_types=down_block_types,
+ up_block_types=up_block_types,
+ only_cross_attention=only_cross_attention,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ cross_attention_dim=cross_attention_dim,
+ transformer_layers_per_block=transformer_layers_per_block,
+ reverse_transformer_layers_per_block=reverse_transformer_layers_per_block,
+ attention_head_dim=attention_head_dim,
+ num_attention_heads=num_attention_heads,
+ )
+
+ # input
+ conv_in_padding = (conv_in_kernel - 1) // 2
+ self.conv_in = nn.Conv2d(
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+ )
+
+ # time
+ time_embed_dim, timestep_input_dim = self._set_time_proj(
+ time_embedding_type,
+ block_out_channels=block_out_channels,
+ flip_sin_to_cos=flip_sin_to_cos,
+ freq_shift=freq_shift,
+ time_embedding_dim=time_embedding_dim,
+ )
+
+ self.time_embedding = TimestepEmbedding(
+ timestep_input_dim,
+ time_embed_dim,
+ act_fn=act_fn,
+ post_act_fn=timestep_post_act,
+ cond_proj_dim=time_cond_proj_dim,
+ )
+
+ self._set_encoder_hid_proj(
+ encoder_hid_dim_type,
+ cross_attention_dim=cross_attention_dim,
+ encoder_hid_dim=encoder_hid_dim,
+ )
+
+ # class embedding
+ self._set_class_embedding(
+ class_embed_type,
+ act_fn=act_fn,
+ num_class_embeds=num_class_embeds,
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
+ time_embed_dim=time_embed_dim,
+ timestep_input_dim=timestep_input_dim,
+ )
+
+ self._set_add_embedding(
+ addition_embed_type,
+ addition_embed_type_num_heads=addition_embed_type_num_heads,
+ addition_time_embed_dim=addition_time_embed_dim,
+ cross_attention_dim=cross_attention_dim,
+ encoder_hid_dim=encoder_hid_dim,
+ flip_sin_to_cos=flip_sin_to_cos,
+ freq_shift=freq_shift,
+ projection_class_embeddings_input_dim=projection_class_embeddings_input_dim,
+ time_embed_dim=time_embed_dim,
+ )
+
+ if time_embedding_act_fn is None:
+ self.time_embed_act = None
+ else:
+ self.time_embed_act = get_activation(time_embedding_act_fn)
+
+ self.down_blocks = nn.ModuleList([])
+ self.up_blocks = nn.ModuleList([])
+
+ if isinstance(only_cross_attention, bool):
+ if mid_block_only_cross_attention is None:
+ mid_block_only_cross_attention = only_cross_attention
+
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+ if mid_block_only_cross_attention is None:
+ mid_block_only_cross_attention = False
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+ if isinstance(cross_attention_dim, int):
+ cross_attention_dim = (cross_attention_dim,) * len(down_block_types)
+
+ if isinstance(layers_per_block, int):
+ layers_per_block = [layers_per_block] * len(down_block_types)
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)
+
+ if class_embeddings_concat:
+ # The time embeddings are concatenated with the class embeddings. The dimension of the
+ # time embeddings passed to the down, middle, and up blocks is twice the dimension of the
+ # regular time embeddings
+ blocks_time_embed_dim = time_embed_dim * 2
+ else:
+ blocks_time_embed_dim = time_embed_dim
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block[i],
+ transformer_layers_per_block=transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim[i],
+ num_attention_heads=num_attention_heads[i],
+ downsample_padding=downsample_padding,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_type=attention_type,
+ resnet_skip_time_act=resnet_skip_time_act,
+ resnet_out_scale_factor=resnet_out_scale_factor,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ dropout=dropout,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ self.mid_block = get_mid_block(
+ mid_block_type,
+ temb_channels=blocks_time_embed_dim,
+ in_channels=block_out_channels[-1],
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ output_scale_factor=mid_block_scale_factor,
+ transformer_layers_per_block=transformer_layers_per_block[-1],
+ num_attention_heads=num_attention_heads[-1],
+ cross_attention_dim=cross_attention_dim[-1],
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ mid_block_only_cross_attention=mid_block_only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_type=attention_type,
+ resnet_skip_time_act=resnet_skip_time_act,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[-1],
+ dropout=dropout,
+ )
+
+ # count how many layers upsample the images
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
+ reversed_layers_per_block = list(reversed(layers_per_block))
+ reversed_cross_attention_dim = list(reversed(cross_attention_dim))
+ reversed_transformer_layers_per_block = (
+ list(reversed(transformer_layers_per_block))
+ if reverse_transformer_layers_per_block is None
+ else reverse_transformer_layers_per_block
+ )
+ only_cross_attention = list(reversed(only_cross_attention))
+
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=reversed_layers_per_block[i] + 1,
+ transformer_layers_per_block=reversed_transformer_layers_per_block[i],
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=blocks_time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resolution_idx=i,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=reversed_cross_attention_dim[i],
+ num_attention_heads=reversed_num_attention_heads[i],
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ attention_type=attention_type,
+ resnet_skip_time_act=resnet_skip_time_act,
+ resnet_out_scale_factor=resnet_out_scale_factor,
+ cross_attention_norm=cross_attention_norm,
+ attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel,
+ dropout=dropout,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ if norm_num_groups is not None:
+ self.conv_norm_out = nn.GroupNorm(
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
+ )
+
+ self.conv_act = get_activation(act_fn)
+
+ else:
+ self.conv_norm_out = None
+ self.conv_act = None
+
+ conv_out_padding = (conv_out_kernel - 1) // 2
+ self.conv_out = nn.Conv2d(
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
+ )
+
+ self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim)
+
+ def _check_config(
+ self,
+ down_block_types: Tuple[str],
+ up_block_types: Tuple[str],
+ only_cross_attention: Union[bool, Tuple[bool]],
+ block_out_channels: Tuple[int],
+ layers_per_block: Union[int, Tuple[int]],
+ cross_attention_dim: Union[int, Tuple[int]],
+ transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]],
+ reverse_transformer_layers_per_block: bool,
+ attention_head_dim: int,
+ num_attention_heads: Optional[Union[int, Tuple[int]]],
+ ):
+ if len(down_block_types) != len(up_block_types):
+ raise ValueError(
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+ )
+
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
+ )
+ if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None:
+ for layer_number_per_block in transformer_layers_per_block:
+ if isinstance(layer_number_per_block, list):
+                    raise ValueError("Must provide `reverse_transformer_layers_per_block` if using asymmetrical UNet.")
+
+ def _set_time_proj(
+ self,
+ time_embedding_type: str,
+ block_out_channels: int,
+ flip_sin_to_cos: bool,
+ freq_shift: float,
+ time_embedding_dim: int,
+ ) -> Tuple[int, int]:
+ if time_embedding_type == "fourier":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 2
+ if time_embed_dim % 2 != 0:
+ raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.")
+ self.time_proj = GaussianFourierProjection(
+ time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos
+ )
+ timestep_input_dim = time_embed_dim
+ elif time_embedding_type == "positional":
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
+
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
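+            # e.g. with the default block_out_channels=(320, 640, 1280, 1280) this yields
+            # timestep_input_dim=320 and time_embed_dim=1280 (unless time_embedding_dim
+            # overrides it).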
+ else:
+ raise ValueError(
+ f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`."
+ )
+
+ return time_embed_dim, timestep_input_dim
+
+ def _set_encoder_hid_proj(
+ self,
+ encoder_hid_dim_type: Optional[str],
+ cross_attention_dim: Union[int, Tuple[int]],
+ encoder_hid_dim: Optional[int],
+ ):
+ if encoder_hid_dim_type is None and encoder_hid_dim is not None:
+ encoder_hid_dim_type = "text_proj"
+ self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type)
+ logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.")
+
+ if encoder_hid_dim is None and encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}."
+ )
+
+ if encoder_hid_dim_type == "text_proj":
+ self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim)
+ elif encoder_hid_dim_type == "text_image_proj":
+ # `image_embed_dim` doesn't have to be `cross_attention_dim`. To avoid cluttering __init__,
+ # it is set to `cross_attention_dim` here, which is exactly the dimension required by the
+ # currently only use case: `encoder_hid_dim_type == "text_image_proj"` (Kandinsky 2.1).
+ self.encoder_hid_proj = TextImageProjection(
+ text_embed_dim=encoder_hid_dim,
+ image_embed_dim=cross_attention_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ elif encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2
+ self.encoder_hid_proj = ImageProjection(
+ image_embed_dim=encoder_hid_dim,
+ cross_attention_dim=cross_attention_dim,
+ )
+ elif encoder_hid_dim_type is not None:
+ raise ValueError(
+ f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
+ )
+ else:
+ self.encoder_hid_proj = None
+
+ def _set_class_embedding(
+ self,
+ class_embed_type: Optional[str],
+ act_fn: str,
+ num_class_embeds: Optional[int],
+ projection_class_embeddings_input_dim: Optional[int],
+ time_embed_dim: int,
+ timestep_input_dim: int,
+ ):
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ elif class_embed_type == "projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ # The projection `class_embed_type` is the same as the timestep `class_embed_type` except
+ # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings
+ # 2. it projects from an arbitrary input dimension.
+ #
+ # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations.
+ # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings.
+ # As a result, `TimestepEmbedding` can be passed arbitrary vectors.
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ elif class_embed_type == "simple_projection":
+ if projection_class_embeddings_input_dim is None:
+ raise ValueError(
+ "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set"
+ )
+ self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim)
+ else:
+ self.class_embedding = None
+
+ def _set_add_embedding(
+ self,
+ addition_embed_type: str,
+ addition_embed_type_num_heads: int,
+ addition_time_embed_dim: Optional[int],
+ flip_sin_to_cos: bool,
+ freq_shift: float,
+ cross_attention_dim: Optional[int],
+ encoder_hid_dim: Optional[int],
+ projection_class_embeddings_input_dim: Optional[int],
+ time_embed_dim: int,
+ ):
+ if addition_embed_type == "text":
+ if encoder_hid_dim is not None:
+ text_time_embedding_from_dim = encoder_hid_dim
+ else:
+ text_time_embedding_from_dim = cross_attention_dim
+
+ self.add_embedding = TextTimeEmbedding(
+ text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads
+ )
+ elif addition_embed_type == "text_image":
+ # `text_embed_dim` and `image_embed_dim` don't have to be `cross_attention_dim`. To avoid cluttering
+ # __init__, they are set to `cross_attention_dim` here, which is exactly the dimension required by the
+ # currently only use case: `addition_embed_type == "text_image"` (Kandinsky 2.1).
+ self.add_embedding = TextImageTimeEmbedding(
+ text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
+ )
+ elif addition_embed_type == "text_time":
+ self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift)
+ self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)
+ elif addition_embed_type == "image":
+ # Kandinsky 2.2
+ self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+ elif addition_embed_type == "image_hint":
+ # Kandinsky 2.2 ControlNet
+ self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
+ elif addition_embed_type is not None:
+ raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.")
+
+ def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
+ if attention_type in ["gated", "gated-text-image"]:
+ positive_len = 768
+ if isinstance(cross_attention_dim, int):
+ positive_len = cross_attention_dim
+ elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
+ positive_len = cross_attention_dim[0]
+
+ feature_type = "text-only" if attention_type == "gated" else "text-image"
+ self.position_net = GLIGENTextBoundingboxProjection(
+ positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type
+ )
+
+ @property
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
+ indexed by their weight names.
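+
+ Example (a minimal, illustrative sketch; `unet` is assumed to be an instance of this model):
+
+ >>> procs = unet.attn_processors
+ >>> # keys are dotted module paths ending in ".processor"
+ >>> all(name.endswith(".processor") for name in procs)
+ True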
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
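+ Example (a minimal sketch; `unet` is assumed to be an instance of this model, and
+ `AttnProcessor2_0` requires PyTorch >= 2.0):
+
+ >>> from diffusers.models.attention_processor import AttnProcessor2_0
+ >>> unet.set_attn_processor(AttnProcessor2_0()) # same processor instance for every Attention layer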
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ def set_default_attn_processor(self):
+ """
+ Disables custom attention processors and sets the default attention implementation.
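+
+ Example (illustrative; `unet` is assumed to be an instance of this model):
+
+ >>> unet.set_default_attn_processor() # reverts any custom processors back to the stock implementation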
+ """
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnAddedKVProcessor()
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor)
+
+ def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module splits the input tensor in slices to compute attention in
+ several steps. This is useful for saving some memory in exchange for a small decrease in speed.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If
+ `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
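+
+ Example (illustrative; `unet` is assumed to be an instance of this model):
+
+ >>> unet.set_attention_slice("auto") # split each sliceable head dimension in two
+ >>> unet.set_attention_slice("max") # one slice at a time, lowest memory use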
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_sliceable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_sliceable_dims(module)
+
+ num_sliceable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_sliceable_layers * [1]
+
+ slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
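+
+ Example (illustrative call only; the numbers are placeholders, please consult the official
+ repository linked above for values recommended for your pipeline):
+
+ >>> unet.enable_freeu(s1=0.9, s2=0.2, b1=1.3, b2=1.4)
+ >>> # ... run inference ...
+ >>> unet.disable_freeu()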
+ """
+ for i, upsample_block in enumerate(self.up_blocks):
+ setattr(upsample_block, "s1", s1)
+ setattr(upsample_block, "s2", s2)
+ setattr(upsample_block, "b1", b1)
+ setattr(upsample_block, "b2", b2)
+
+ def disable_freeu(self):
+ """Disables the FreeU mechanism."""
+ freeu_keys = {"s1", "s2", "b1", "b2"}
+ for i, upsample_block in enumerate(self.up_blocks):
+ for k in freeu_keys:
+ if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
+ setattr(upsample_block, k, None)
+
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+
+
+
+ This API is 🧪 experimental.
+
+
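+ Example (a minimal sketch; `unet` is assumed to be an instance of this model):
+
+ >>> unet.fuse_qkv_projections()
+ >>> # ... run inference with fused projections ...
+ >>> unet.unfuse_qkv_projections() # restores the original processors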
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+
+
+ This API is 🧪 experimental.
+
+
+
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
+ def unload_lora(self):
+ """Unloads LoRA weights."""
+ deprecate(
+ "unload_lora",
+ "0.28.0",
+ "Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().",
+ )
+ for module in self.modules():
+ if hasattr(module, "set_lora_layer"):
+ module.set_lora_layer(None)
+
+ def get_time_embed(
+ self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int]
+ ) -> Optional[torch.Tensor]:
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=sample.dtype)
+ return t_emb
+
+ def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
+ class_emb = None
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ # `Timesteps` does not contain any weights and will always return f32 tensors
+ # there might be better ways to encapsulate this.
+ class_labels = class_labels.to(dtype=sample.dtype)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
+ return class_emb
+
+ def get_aug_embed(
+ self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
+ ) -> Optional[torch.Tensor]:
+ aug_emb = None
+ if self.config.addition_embed_type == "text":
+ aug_emb = self.add_embedding(encoder_hidden_states)
+ elif self.config.addition_embed_type == "text_image":
+ # Kandinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ )
+
+ image_embs = added_cond_kwargs.get("image_embeds")
+ text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
+ aug_emb = self.add_embedding(text_embs, image_embs)
+ elif self.config.addition_embed_type == "text_time":
+ # SDXL - style
+ if "text_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
+ )
+ text_embeds = added_cond_kwargs.get("text_embeds")
+ if "time_ids" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
+ )
+ time_ids = added_cond_kwargs.get("time_ids")
+ time_embeds = self.add_time_proj(time_ids.flatten())
+ time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
+ add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
+ add_embeds = add_embeds.to(emb.dtype)
+ aug_emb = self.add_embedding(add_embeds)
+ elif self.config.addition_embed_type == "image":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
+ )
+ image_embs = added_cond_kwargs.get("image_embeds")
+ aug_emb = self.add_embedding(image_embs)
+ elif self.config.addition_embed_type == "image_hint":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
+ )
+ image_embs = added_cond_kwargs.get("image_embeds")
+ hint = added_cond_kwargs.get("hint")
+ aug_emb = self.add_embedding(image_embs, hint)
+ return aug_emb
+
+ def process_encoder_hidden_states(
+ self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any]
+ ) -> torch.Tensor:
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
+ # Kandinsky 2.1 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
+ # Kandinsky 2.2 - style
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ encoder_hidden_states = self.encoder_hid_proj(image_embeds)
+ elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ image_embeds = self.encoder_hid_proj(image_embeds)
+ encoder_hidden_states = (encoder_hidden_states, image_embeds)
+ return encoder_hidden_states
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ class_labels: Optional[torch.Tensor] = None,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ down_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
+ mid_block_add_sample: Optional[Tuple[torch.Tensor]] = None,
+ up_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
+ features_adapter: Optional[torch.Tensor] = None,
+ ) -> Union[UNet2DConditionOutput, Tuple]:
+ r"""
+ The [`UNet2DConditionModel`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The noisy input tensor with the following shape `(batch, channel, height, width)`.
+ timestep (`torch.FloatTensor` or `float` or `int`): The current denoising timestep.
+ encoder_hidden_states (`torch.FloatTensor`):
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
+ class_labels (`torch.Tensor`, *optional*, defaults to `None`):
+ Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ added_cond_kwargs: (`dict`, *optional*):
+ A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
+ are passed along to the UNet blocks.
+ encoder_attention_mask (`torch.Tensor`):
+ A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
+ `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
+ which adds large negative values to the attention scores corresponding to "discard" tokens.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+ tuple.
+ down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
+ additional residuals to be added to UNet long skip connections from down blocks to up blocks for
+ example from ControlNet side model(s)
+ mid_block_additional_residual (`torch.Tensor`, *optional*):
+ additional residual to be added to UNet mid block output, for example from ControlNet side model
+ down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
+ additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s)
+ features_adapter (`torch.FloatTensor`, *optional*):
+ Adapter features tensor of shape `(batch, channels, num_frames, height, width)`.
+
+ Returns:
+ [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
+ If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
+ a `tuple` is returned where the first element is the sample tensor.
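+
+ Example (a shape-only sketch; the tensor sizes assume a 4-channel latent UNet paired with a
+ 77-token, 768-dim text encoder and are illustrative, not read from this model's config):
+
+ >>> sample = torch.randn(1, 4, 64, 64)
+ >>> encoder_hidden_states = torch.randn(1, 77, 768)
+ >>> out = unet(sample, timestep=10, encoder_hidden_states=encoder_hidden_states)
+ >>> out.sample.shape
+ torch.Size([1, 4, 64, 64])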
+ """
+ # By default, samples have to be at least a multiple of the overall upsampling factor.
+ # The overall upsampling factor is equal to 2 ** (number of upsampling layers).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
+ # on the fly if necessary.
+ default_overall_up_factor = 2**self.num_upsamplers
+
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+ forward_upsample_size = False
+ upsample_size = None
+
+ for dim in sample.shape[-2:]:
+ if dim % default_overall_up_factor != 0:
+ # Forward upsample size to force interpolation output size.
+ forward_upsample_size = True
+ break
+
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is not None:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ # 0. center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # 1. time
+ t_emb = self.get_time_embed(sample=sample, timestep=timestep)
+ emb = self.time_embedding(t_emb, timestep_cond)
+ aug_emb = None
+
+ class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
+ if class_emb is not None:
+ if self.config.class_embeddings_concat:
+ emb = torch.cat([emb, class_emb], dim=-1)
+ else:
+ emb = emb + class_emb
+
+ aug_emb = self.get_aug_embed(
+ emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
+ )
+ if self.config.addition_embed_type == "image_hint":
+ aug_emb, hint = aug_emb
+ sample = torch.cat([sample, hint], dim=1)
+
+ emb = emb + aug_emb if aug_emb is not None else emb
+
+ if self.time_embed_act is not None:
+ emb = self.time_embed_act(emb)
+
+ encoder_hidden_states = self.process_encoder_hidden_states(
+ encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
+ )
+
+ # 2. pre-process
+ sample = self.conv_in(sample)
+
+ # 2.5 GLIGEN position net
+ if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
+ cross_attention_kwargs = cross_attention_kwargs.copy()
+ gligen_args = cross_attention_kwargs.pop("gligen")
+ cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
+
+ # 3. down
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
+ is_adapter = down_intrablock_additional_residuals is not None
+ # maintain backward compatibility for legacy usage, where
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
+ # but can only use one or the other
+ is_brushnet = (
+ down_block_add_samples is not None
+ and mid_block_add_sample is not None
+ and up_block_add_samples is not None
+ )
+ if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
+ deprecate(
+ "T2I should not use down_block_additional_residuals",
+ "1.3.0",
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
+ for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ",
+ standard_warn=False,
+ )
+ down_intrablock_additional_residuals = down_block_additional_residuals
+ is_adapter = True
+
+ down_block_res_samples = (sample,)
+
+ if is_brushnet:
+ sample = sample + down_block_add_samples.pop(0)
+
+ adapter_idx = 0
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ # For t2i-adapter CrossAttnDownBlock2D
+ additional_residuals = {}
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
+ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
+
+ if is_brushnet and len(down_block_add_samples) > 0:
+ additional_residuals["down_block_add_samples"] = [
+ down_block_add_samples.pop(0)
+ for _ in range(len(downsample_block.resnets) + (downsample_block.downsamplers is not None))
+ ]
+
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ **additional_residuals,
+ )
+ else:
+ additional_residuals = {}
+ if is_brushnet and len(down_block_add_samples) > 0:
+ additional_residuals["down_block_add_samples"] = [
+ down_block_add_samples.pop(0)
+ for _ in range(len(downsample_block.resnets) + (downsample_block.downsamplers is not None))
+ ]
+
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale, **additional_residuals)
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
+ sample += down_intrablock_additional_residuals.pop(0)
+
+ if features_adapter is not None:
+ sample += features_adapter[adapter_idx]
+ adapter_idx += 1
+
+ down_block_res_samples += res_samples
+
+ if features_adapter is not None:
+ assert len(features_adapter) == adapter_idx, f"`features_adapter` must provide one feature map per down block ({adapter_idx}), got {len(features_adapter)}."
+
+ if is_controlnet:
+ new_down_block_res_samples = ()
+
+ for down_block_res_sample, down_block_additional_residual in zip(
+ down_block_res_samples, down_block_additional_residuals
+ ):
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = new_down_block_res_samples
+
+ # 4. mid
+ if self.mid_block is not None:
+ if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ encoder_attention_mask=encoder_attention_mask,
+ )
+ else:
+ sample = self.mid_block(sample, emb)
+
+ # To support T2I-Adapter-XL
+ if (
+ is_adapter
+ and len(down_intrablock_additional_residuals) > 0
+ and sample.shape == down_intrablock_additional_residuals[0].shape
+ ):
+ sample += down_intrablock_additional_residuals.pop(0)
+
+ if is_controlnet:
+ sample = sample + mid_block_additional_residual
+
+ if is_brushnet:
+ sample = sample + mid_block_add_sample
+
+ # 5. up
+ for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ additional_residuals = {}
+ if is_brushnet and len(up_block_add_samples) > 0:
+ additional_residuals["up_block_add_samples"] = [
+ up_block_add_samples.pop(0)
+ for _ in range(len(upsample_block.resnets) + (upsample_block.upsamplers is not None))
+ ]
+
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ **additional_residuals,
+ )
+ else:
+ additional_residuals = {}
+ if is_brushnet and len(up_block_add_samples) > 0:
+ additional_residuals["up_block_add_samples"] = [
+ up_block_add_samples.pop(0)
+ for _ in range(len(upsample_block.resnets) + (upsample_block.upsamplers is not None))
+ ]
+
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ upsample_size=upsample_size,
+ scale=lora_scale,
+ **additional_residuals,
+ )
+
+ # 6. post-process
+ if self.conv_norm_out:
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet2DConditionOutput(sample=sample)
diff --git a/libs/unet_3d_blocks.py b/libs/unet_3d_blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b6ce7413919d80a225a1f86a3959c5c022ed411
--- /dev/null
+++ b/libs/unet_3d_blocks.py
@@ -0,0 +1,2463 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+from torch import nn
+
+from diffusers.utils import is_torch_version
+from diffusers.utils.torch_utils import apply_freeu
+from diffusers.models.attention import Attention
+from diffusers.models.transformers.dual_transformer_2d import DualTransformer2DModel
+from diffusers.models.resnet import (
+ Downsample2D,
+ ResnetBlock2D,
+ SpatioTemporalResBlock,
+ TemporalConvLayer,
+ Upsample2D,
+)
+from diffusers.models.transformers.transformer_2d import Transformer2DModel
+from diffusers.models.transformers.transformer_temporal import (
+ TransformerSpatioTemporalModel,
+)
+from libs.transformer_temporal import TransformerTemporalModel
+
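+# Factory helpers: map a block-type string from the UNet config to the corresponding
+# down/up block class, forwarding only the arguments that block understands.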
+def get_down_block(
+ down_block_type: str,
+ num_layers: int,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ add_downsample: bool,
+ resnet_eps: float,
+ resnet_act_fn: str,
+ num_attention_heads: int,
+ resnet_groups: Optional[int] = None,
+ cross_attention_dim: Optional[int] = None,
+ downsample_padding: Optional[int] = None,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = True,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ temporal_num_attention_heads: int = 8,
+ temporal_max_seq_length: int = 32,
+ transformer_layers_per_block: int = 1,
+) -> Union[
+ "DownBlock3D",
+ "CrossAttnDownBlock3D",
+ "DownBlockMotion",
+ "CrossAttnDownBlockMotion",
+ "DownBlockSpatioTemporal",
+ "CrossAttnDownBlockSpatioTemporal",
+]:
+ if down_block_type == "DownBlock3D":
+ return DownBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "CrossAttnDownBlock3D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
+ return CrossAttnDownBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ if down_block_type == "DownBlockMotion":
+ return DownBlockMotion(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ temporal_num_attention_heads=temporal_num_attention_heads,
+ temporal_max_seq_length=temporal_max_seq_length,
+ )
+ elif down_block_type == "CrossAttnDownBlockMotion":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMotion")
+ return CrossAttnDownBlockMotion(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ temporal_num_attention_heads=temporal_num_attention_heads,
+ temporal_max_seq_length=temporal_max_seq_length,
+ )
+ elif down_block_type == "DownBlockSpatioTemporal":
+ # added for SDV
+ return DownBlockSpatioTemporal(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ )
+ elif down_block_type == "CrossAttnDownBlockSpatioTemporal":
+ # added for SDV
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockSpatioTemporal")
+ return CrossAttnDownBlockSpatioTemporal(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ num_layers=num_layers,
+ transformer_layers_per_block=transformer_layers_per_block,
+ add_downsample=add_downsample,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads,
+ )
+
+ raise ValueError(f"{down_block_type} does not exist.")
+
+
+def get_up_block(
+ up_block_type: str,
+ num_layers: int,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ add_upsample: bool,
+ resnet_eps: float,
+ resnet_act_fn: str,
+ num_attention_heads: int,
+ resolution_idx: Optional[int] = None,
+ resnet_groups: Optional[int] = None,
+ cross_attention_dim: Optional[int] = None,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = True,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ temporal_num_attention_heads: int = 8,
+ temporal_cross_attention_dim: Optional[int] = None,
+ temporal_max_seq_length: int = 32,
+ transformer_layers_per_block: int = 1,
+ dropout: float = 0.0,
+) -> Union[
+ "UpBlock3D",
+ "CrossAttnUpBlock3D",
+ "UpBlockMotion",
+ "CrossAttnUpBlockMotion",
+ "UpBlockSpatioTemporal",
+ "CrossAttnUpBlockSpatioTemporal",
+]:
+ if up_block_type == "UpBlock3D":
+ return UpBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ resolution_idx=resolution_idx,
+ )
+ elif up_block_type == "CrossAttnUpBlock3D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
+ return CrossAttnUpBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ resolution_idx=resolution_idx,
+ )
+ if up_block_type == "UpBlockMotion":
+ return UpBlockMotion(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ resolution_idx=resolution_idx,
+ temporal_num_attention_heads=temporal_num_attention_heads,
+ temporal_max_seq_length=temporal_max_seq_length,
+ )
+ elif up_block_type == "CrossAttnUpBlockMotion":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMotion")
+ return CrossAttnUpBlockMotion(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ resolution_idx=resolution_idx,
+ temporal_num_attention_heads=temporal_num_attention_heads,
+ temporal_max_seq_length=temporal_max_seq_length,
+ )
+ elif up_block_type == "UpBlockSpatioTemporal":
+ # added for SDV
+ return UpBlockSpatioTemporal(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ resolution_idx=resolution_idx,
+ add_upsample=add_upsample,
+ )
+ elif up_block_type == "CrossAttnUpBlockSpatioTemporal":
+ # added for SDV
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockSpatioTemporal")
+ return CrossAttnUpBlockSpatioTemporal(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ num_layers=num_layers,
+ transformer_layers_per_block=transformer_layers_per_block,
+ add_upsample=add_upsample,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads,
+ resolution_idx=resolution_idx,
+ )
+
+ raise ValueError(f"{up_block_type} does not exist.")
+
+
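+# Mid block of the 3D UNet: interleaves spatial ResNet and cross-attention layers with temporal
+# convolution and temporal attention layers so features are also mixed across frames.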
+class UNetMidBlock3DCrossAttn(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ num_attention_heads: int = 1,
+ output_scale_factor: float = 1.0,
+ cross_attention_dim: int = 1280,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = True,
+ upcast_attention: bool = False,
+ ):
+ super().__init__()
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+ temp_convs = [
+ TemporalConvLayer(
+ in_channels,
+ in_channels,
+ dropout=0.1,
+ norm_num_groups=resnet_groups,
+ )
+ ]
+ attentions = []
+ temp_attentions = []
+
+ for _ in range(num_layers):
+ attentions.append(
+ Transformer2DModel(
+ in_channels // num_attention_heads,
+ num_attention_heads,
+ in_channels=in_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ )
+ )
+ temp_attentions.append(
+ TransformerTemporalModel(
+ in_channels // num_attention_heads,
+ num_attention_heads,
+ in_channels=in_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ temp_convs.append(
+ TemporalConvLayer(
+ in_channels,
+ in_channels,
+ dropout=0.1,
+ norm_num_groups=resnet_groups,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.temp_convs = nn.ModuleList(temp_convs)
+ self.attentions = nn.ModuleList(attentions)
+ self.temp_attentions = nn.ModuleList(temp_attentions)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ num_frames: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> torch.FloatTensor:
+ hidden_states = self.resnets[0](hidden_states, temb)
+ hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
+ for attn, temp_attn, resnet, temp_conv in zip(
+ self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:]
+ ):
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+ hidden_states = temp_attn(
+ hidden_states,
+ num_frames=num_frames,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
+
+ return hidden_states
+
+
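+# Down block with spatial cross-attention plus temporal convolution/attention per layer; returns
+# both the downsampled hidden states and the per-layer skip states for the up path.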
+class CrossAttnDownBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ num_attention_heads: int = 1,
+ cross_attention_dim: int = 1280,
+ output_scale_factor: float = 1.0,
+ downsample_padding: int = 1,
+ add_downsample: bool = True,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+ temp_attentions = []
+ temp_convs = []
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ temp_convs.append(
+ TemporalConvLayer(
+ out_channels,
+ out_channels,
+ dropout=0.1,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ attentions.append(
+ Transformer2DModel(
+ out_channels // num_attention_heads,
+ num_attention_heads,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ )
+ )
+ temp_attentions.append(
+ TransformerTemporalModel(
+ out_channels // num_attention_heads,
+ num_attention_heads,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ self.resnets = nn.ModuleList(resnets)
+ self.temp_convs = nn.ModuleList(temp_convs)
+ self.attentions = nn.ModuleList(attentions)
+ self.temp_attentions = nn.ModuleList(temp_attentions)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels,
+ use_conv=True,
+ out_channels=out_channels,
+ padding=downsample_padding,
+ name="op",
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ num_frames: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ # TODO(Patrick, William) - attention mask is not used
+ output_states = ()
+
+ for resnet, temp_conv, attn, temp_attn in zip(
+ self.resnets, self.temp_convs, self.attentions, self.temp_attentions
+ ):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+ hidden_states = temp_attn(
+ hidden_states,
+ num_frames=num_frames,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class DownBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor: float = 1.0,
+ add_downsample: bool = True,
+ downsample_padding: int = 1,
+ ):
+ super().__init__()
+ resnets = []
+ temp_convs = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ temp_convs.append(
+ TemporalConvLayer(
+ out_channels,
+ out_channels,
+ dropout=0.1,
+ norm_num_groups=resnet_groups,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.temp_convs = nn.ModuleList(temp_convs)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels,
+ use_conv=True,
+ out_channels=out_channels,
+ padding=downsample_padding,
+ name="op",
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ num_frames: int = 1,
+ ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ output_states = ()
+
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
+
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class CrossAttnUpBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ num_attention_heads: int = 1,
+ cross_attention_dim: int = 1280,
+ output_scale_factor: float = 1.0,
+ add_upsample: bool = True,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ resolution_idx: Optional[int] = None,
+ ):
+ super().__init__()
+ resnets = []
+ temp_convs = []
+ attentions = []
+ temp_attentions = []
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ temp_convs.append(
+ TemporalConvLayer(
+ out_channels,
+ out_channels,
+ dropout=0.1,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ attentions.append(
+ Transformer2DModel(
+ out_channels // num_attention_heads,
+ num_attention_heads,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ )
+ )
+ temp_attentions.append(
+ TransformerTemporalModel(
+ out_channels // num_attention_heads,
+ num_attention_heads,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ self.resnets = nn.ModuleList(resnets)
+ self.temp_convs = nn.ModuleList(temp_convs)
+ self.attentions = nn.ModuleList(attentions)
+ self.temp_attentions = nn.ModuleList(temp_attentions)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ upsample_size: Optional[int] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ num_frames: int = 1,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ) -> torch.FloatTensor:
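+        # `s1`/`s2`/`b1`/`b2` are not constructor arguments; they are attached to
+        # the block externally (e.g. via the pipeline's `enable_freeu()`), so
+        # their presence is what toggles the FreeU path below.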
+ is_freeu_enabled = (
+ getattr(self, "s1", None)
+ and getattr(self, "s2", None)
+ and getattr(self, "b1", None)
+ and getattr(self, "b2", None)
+ )
+
+ # TODO(Patrick, William) - attention mask is not used
+ for resnet, temp_conv, attn, temp_attn in zip(
+ self.resnets, self.temp_convs, self.attentions, self.temp_attentions
+ ):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+ # FreeU: Only operate on the first two stages
+ if is_freeu_enabled:
+ hidden_states, res_hidden_states = apply_freeu(
+ self.resolution_idx,
+ hidden_states,
+ res_hidden_states,
+ s1=self.s1,
+ s2=self.s2,
+ b1=self.b1,
+ b2=self.b2,
+ )
+
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+ hidden_states = temp_attn(
+ hidden_states,
+ num_frames=num_frames,
+ cross_attention_kwargs=cross_attention_kwargs,
+ return_dict=False,
+ )[0]
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+
+class UpBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor: float = 1.0,
+ add_upsample: bool = True,
+ resolution_idx: Optional[int] = None,
+ ):
+ super().__init__()
+ resnets = []
+ temp_convs = []
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ temp_convs.append(
+ TemporalConvLayer(
+ out_channels,
+ out_channels,
+ dropout=0.1,
+ norm_num_groups=resnet_groups,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.temp_convs = nn.ModuleList(temp_convs)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ upsample_size: Optional[int] = None,
+ num_frames: int = 1,
+ ) -> torch.FloatTensor:
+ is_freeu_enabled = (
+ getattr(self, "s1", None)
+ and getattr(self, "s2", None)
+ and getattr(self, "b1", None)
+ and getattr(self, "b2", None)
+ )
+ for resnet, temp_conv in zip(self.resnets, self.temp_convs):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+ # FreeU: Only operate on the first two stages
+ if is_freeu_enabled:
+ hidden_states, res_hidden_states = apply_freeu(
+ self.resolution_idx,
+ hidden_states,
+ res_hidden_states,
+ s1=self.s1,
+ s2=self.s2,
+ b1=self.b1,
+ b2=self.b2,
+ )
+
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = temp_conv(hidden_states, num_frames=num_frames)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+
+class DownBlockMotion(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor: float = 1.0,
+ add_downsample: bool = True,
+ downsample_padding: int = 1,
+ temporal_num_attention_heads: int = 1,
+ temporal_cross_attention_dim: Optional[int] = None,
+ temporal_max_seq_length: int = 32,
+ ):
+ super().__init__()
+ resnets = []
+ motion_modules = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ motion_modules.append(
+ TransformerTemporalModel(
+ num_attention_heads=temporal_num_attention_heads,
+ in_channels=out_channels,
+ norm_num_groups=resnet_groups,
+ cross_attention_dim=temporal_cross_attention_dim,
+ attention_bias=False,
+ activation_fn="geglu",
+ positional_embeddings="sinusoidal",
+ num_positional_embeddings=temporal_max_seq_length,
+ attention_head_dim=out_channels // temporal_num_attention_heads,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.motion_modules = nn.ModuleList(motion_modules)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels,
+ use_conv=True,
+ out_channels=out_channels,
+ padding=downsample_padding,
+ name="op",
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ down_block_add_samples: Optional[torch.FloatTensor] = None,
+ temb: Optional[torch.FloatTensor] = None,
+ scale: float = 1.0,
+ num_frames: int = 1,
+    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ output_states = ()
+
+ blocks = zip(self.resnets, self.motion_modules)
+ for resnet, motion_module in blocks:
+ if self.training and self.gradient_checkpointing:
+
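+                # torch.utils.checkpoint re-runs this callable during backward,
+                # so the wrapper closes over options such as `return_dict` and
+                # exposes a plain positional interface; `use_reentrant=False` is
+                # selected below on torch >= 1.11.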
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+
+ if down_block_add_samples is not None:
+ hidden_states = hidden_states + down_block_add_samples.pop(0)
+
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(motion_module),
+ hidden_states.requires_grad_(),
+ temb,
+ num_frames,
+ **ckpt_kwargs,
+ )
+
+ else:
+ hidden_states = resnet(hidden_states, temb, scale=scale)
+ if down_block_add_samples is not None:
+ hidden_states = hidden_states + down_block_add_samples.pop(0)
+ hidden_states = motion_module(hidden_states, num_frames=num_frames)
+
+ output_states = output_states + (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states, scale=scale)
+
+ if down_block_add_samples is not None:
+ hidden_states = hidden_states + down_block_add_samples.pop(0)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
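+# Residual-injection contract for the *Motion down blocks (a descriptive sketch
+# derived from the forward passes in this file): when `down_block_add_samples`
+# is supplied, one tensor is popped per resnet/attention layer and one per
+# downsampler, so the caller is expected to pass a mutable list of
+# `num_layers + 1` tensors (or `num_layers` when `add_downsample=False`) whose
+# shapes broadcast against `hidden_states`.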
+
+class CrossAttnDownBlockMotion(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ transformer_layers_per_block: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ num_attention_heads: int = 1,
+ cross_attention_dim: int = 1280,
+ output_scale_factor: float = 1.0,
+ downsample_padding: int = 1,
+ add_downsample: bool = True,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ attention_type: str = "default",
+ temporal_cross_attention_dim: Optional[int] = None,
+ temporal_num_attention_heads: int = 8,
+ temporal_max_seq_length: int = 32,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+ motion_modules = []
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ if not dual_cross_attention:
+ attentions.append(
+ Transformer2DModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=transformer_layers_per_block,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ attention_type=attention_type,
+ )
+ )
+ else:
+ attentions.append(
+ DualTransformer2DModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+
+ motion_modules.append(
+ TransformerTemporalModel(
+ num_attention_heads=temporal_num_attention_heads,
+ in_channels=out_channels,
+ norm_num_groups=resnet_groups,
+ cross_attention_dim=temporal_cross_attention_dim,
+ attention_bias=False,
+ activation_fn="geglu",
+ positional_embeddings="sinusoidal",
+ num_positional_embeddings=temporal_max_seq_length,
+ attention_head_dim=out_channels // temporal_num_attention_heads,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+ self.motion_modules = nn.ModuleList(motion_modules)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels,
+ use_conv=True,
+ out_channels=out_channels,
+ padding=downsample_padding,
+ name="op",
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ num_frames: int = 1,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ additional_residuals: Optional[torch.FloatTensor] = None,
+ down_block_add_samples: Optional[torch.FloatTensor] = None,
+    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ output_states = ()
+
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+
+ blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
+ for i, (resnet, attn, motion_module) in enumerate(blocks):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
+ if i == len(blocks) - 1 and additional_residuals is not None:
+ hidden_states = hidden_states + additional_residuals
+ if down_block_add_samples is not None:
+ hidden_states = hidden_states + down_block_add_samples.pop(0)
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(motion_module),
+ hidden_states.requires_grad_(),
+ temb,
+ num_frames,
+ **ckpt_kwargs,
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ # apply additional residuals to the output of the last pair of resnet and attention blocks
+ if i == len(blocks) - 1 and additional_residuals is not None:
+ hidden_states = hidden_states + additional_residuals
+ if down_block_add_samples is not None:
+ hidden_states = hidden_states + down_block_add_samples.pop(0)
+ hidden_states = motion_module(
+ hidden_states,
+ num_frames=num_frames,
+ )
+
+ output_states = output_states + (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states, scale=lora_scale)
+
+ if down_block_add_samples is not None:
+ hidden_states = hidden_states + down_block_add_samples.pop(0)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class CrossAttnUpBlockMotion(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ resolution_idx: Optional[int] = None,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ transformer_layers_per_block: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ num_attention_heads: int = 1,
+ cross_attention_dim: int = 1280,
+ output_scale_factor: float = 1.0,
+ add_upsample: bool = True,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ attention_type: str = "default",
+ temporal_cross_attention_dim: Optional[int] = None,
+ temporal_num_attention_heads: int = 8,
+ temporal_max_seq_length: int = 32,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+ motion_modules = []
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ if not dual_cross_attention:
+ attentions.append(
+ Transformer2DModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=transformer_layers_per_block,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ attention_type=attention_type,
+ )
+ )
+ else:
+ attentions.append(
+ DualTransformer2DModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ motion_modules.append(
+ TransformerTemporalModel(
+ num_attention_heads=temporal_num_attention_heads,
+ in_channels=out_channels,
+ norm_num_groups=resnet_groups,
+ cross_attention_dim=temporal_cross_attention_dim,
+ attention_bias=False,
+ activation_fn="geglu",
+ positional_embeddings="sinusoidal",
+ num_positional_embeddings=temporal_max_seq_length,
+ attention_head_dim=out_channels // temporal_num_attention_heads,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+ self.motion_modules = nn.ModuleList(motion_modules)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ upsample_size: Optional[int] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ num_frames: int = 1,
+ up_block_add_samples: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+ is_freeu_enabled = (
+ getattr(self, "s1", None)
+ and getattr(self, "s2", None)
+ and getattr(self, "b1", None)
+ and getattr(self, "b2", None)
+ )
+
+ blocks = zip(self.resnets, self.attentions, self.motion_modules)
+ for resnet, attn, motion_module in blocks:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+ # FreeU: Only operate on the first two stages
+ if is_freeu_enabled:
+ hidden_states, res_hidden_states = apply_freeu(
+ self.resolution_idx,
+ hidden_states,
+ res_hidden_states,
+ s1=self.s1,
+ s2=self.s2,
+ b1=self.b1,
+ b2=self.b2,
+ )
+
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ if up_block_add_samples is not None:
+ hidden_states = hidden_states + up_block_add_samples.pop(0)
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(motion_module),
+ hidden_states.requires_grad_(),
+ temb,
+ num_frames,
+ **ckpt_kwargs,
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+ if up_block_add_samples is not None:
+ hidden_states = hidden_states + up_block_add_samples.pop(0)
+ hidden_states = motion_module(
+ hidden_states,
+ num_frames=num_frames,
+ )
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
+ if up_block_add_samples is not None:
+ hidden_states = hidden_states + up_block_add_samples.pop(0)
+
+ return hidden_states
+
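+# The *Motion up blocks mirror the contract above: `up_block_add_samples`, when
+# given, is consumed via pop(0) once per resnet/attention layer and once per
+# upsampler, so it should hold `num_layers + 1` tensors when `add_upsample=True`.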
+
+class UpBlockMotion(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ resolution_idx: Optional[int] = None,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor: float = 1.0,
+ add_upsample: bool = True,
+ temporal_norm_num_groups: int = 32,
+ temporal_cross_attention_dim: Optional[int] = None,
+ temporal_num_attention_heads: int = 8,
+ temporal_max_seq_length: int = 32,
+ ):
+ super().__init__()
+ resnets = []
+ motion_modules = []
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ motion_modules.append(
+ TransformerTemporalModel(
+ num_attention_heads=temporal_num_attention_heads,
+ in_channels=out_channels,
+ norm_num_groups=temporal_norm_num_groups,
+ cross_attention_dim=temporal_cross_attention_dim,
+ attention_bias=False,
+ activation_fn="geglu",
+ positional_embeddings="sinusoidal",
+ num_positional_embeddings=temporal_max_seq_length,
+ attention_head_dim=out_channels // temporal_num_attention_heads,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+ self.motion_modules = nn.ModuleList(motion_modules)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+        upsample_size: Optional[int] = None,
+ scale: float = 1.0,
+ num_frames: int = 1,
+ up_block_add_samples: Optional[torch.FloatTensor] = None,
+ ) -> torch.FloatTensor:
+ is_freeu_enabled = (
+ getattr(self, "s1", None)
+ and getattr(self, "s2", None)
+ and getattr(self, "b1", None)
+ and getattr(self, "b2", None)
+ )
+
+ blocks = zip(self.resnets, self.motion_modules)
+
+ for resnet, motion_module in blocks:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+ # FreeU: Only operate on the first two stages
+ if is_freeu_enabled:
+ hidden_states, res_hidden_states = apply_freeu(
+ self.resolution_idx,
+ hidden_states,
+ res_hidden_states,
+ s1=self.s1,
+ s2=self.s2,
+ b1=self.b1,
+ b2=self.b2,
+ )
+
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ use_reentrant=False,
+ )
+ else:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet), hidden_states, temb
+ )
+
+ if up_block_add_samples is not None:
+ hidden_states = hidden_states + up_block_add_samples.pop(0)
+
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(motion_module),
+ hidden_states.requires_grad_(),
+ temb,
+ num_frames,
+ use_reentrant=False,
+ )
+ else:
+ hidden_states = resnet(hidden_states, temb, scale=scale)
+ if up_block_add_samples is not None:
+ hidden_states = hidden_states + up_block_add_samples.pop(0)
+ hidden_states = motion_module(hidden_states, num_frames=num_frames)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
+
+ if up_block_add_samples is not None:
+ hidden_states = hidden_states + up_block_add_samples.pop(0)
+
+ return hidden_states
+
+
+class UNetMidBlockCrossAttnMotion(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ transformer_layers_per_block: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ num_attention_heads: int = 1,
+ output_scale_factor: float = 1.0,
+ cross_attention_dim: int = 1280,
+        dual_cross_attention: bool = False,
+        use_linear_projection: bool = False,
+        upcast_attention: bool = False,
+ attention_type: str = "default",
+ temporal_num_attention_heads: int = 1,
+ temporal_cross_attention_dim: Optional[int] = None,
+ temporal_max_seq_length: int = 32,
+ ):
+ super().__init__()
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+ attentions = []
+ motion_modules = []
+
+ for _ in range(num_layers):
+ if not dual_cross_attention:
+ attentions.append(
+ Transformer2DModel(
+ num_attention_heads,
+ in_channels // num_attention_heads,
+ in_channels=in_channels,
+ num_layers=transformer_layers_per_block,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ attention_type=attention_type,
+ )
+ )
+ else:
+ attentions.append(
+ DualTransformer2DModel(
+ num_attention_heads,
+ in_channels // num_attention_heads,
+ in_channels=in_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ )
+ )
+ resnets.append(
+ ResnetBlock2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ motion_modules.append(
+ TransformerTemporalModel(
+ num_attention_heads=temporal_num_attention_heads,
+ attention_head_dim=in_channels // temporal_num_attention_heads,
+ in_channels=in_channels,
+ norm_num_groups=resnet_groups,
+ cross_attention_dim=temporal_cross_attention_dim,
+ attention_bias=False,
+ positional_embeddings="sinusoidal",
+ num_positional_embeddings=temporal_max_seq_length,
+ activation_fn="geglu",
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+ self.motion_modules = nn.ModuleList(motion_modules)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.FloatTensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
+ num_frames: int = 1,
+        mid_block_add_sample: Optional[torch.Tensor] = None,
+ ) -> torch.FloatTensor:
+ lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+ hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
+
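+        # Per layer, both branches below apply the same sequence: spatial
+        # attention -> resnet (where the optional `mid_block_add_sample`
+        # residual is injected) -> temporal motion module -> resnet again.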
+ blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
+ for attn, resnet, motion_module in blocks:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+                # inject the external mid-block residual right after this resnet
+                hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+                if mid_block_add_sample is not None:
+                    hidden_states = hidden_states + mid_block_add_sample
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(motion_module),
+ hidden_states.requires_grad_(),
+ temb,
+ num_frames,
+ **ckpt_kwargs,
+ )
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+ else:
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ cross_attention_kwargs=cross_attention_kwargs,
+ attention_mask=attention_mask,
+ encoder_attention_mask=encoder_attention_mask,
+ return_dict=False,
+ )[0]
+                # inject the external mid-block residual right after this resnet
+                hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+                if mid_block_add_sample is not None:
+                    hidden_states = hidden_states + mid_block_add_sample
+ hidden_states = motion_module(
+ hidden_states,
+ num_frames=num_frames,
+ )
+ hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+
+ return hidden_states
+
+
+class MidBlockTemporalDecoder(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ attention_head_dim: int = 512,
+ num_layers: int = 1,
+ upcast_attention: bool = False,
+ ):
+ super().__init__()
+
+ resnets = []
+ attentions = []
+ for i in range(num_layers):
+ input_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ SpatioTemporalResBlock(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=1e-6,
+ temporal_eps=1e-5,
+ merge_factor=0.0,
+ merge_strategy="learned",
+ switch_spatial_to_temporal_mix=True,
+ )
+ )
+
+ attentions.append(
+ Attention(
+ query_dim=in_channels,
+ heads=in_channels // attention_head_dim,
+ dim_head=attention_head_dim,
+ eps=1e-6,
+ upcast_attention=upcast_attention,
+ norm_num_groups=32,
+ bias=True,
+ residual_connection=True,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ image_only_indicator: torch.FloatTensor,
+ ):
+ hidden_states = self.resnets[0](
+ hidden_states,
+ image_only_indicator=image_only_indicator,
+ )
+ for resnet, attn in zip(self.resnets[1:], self.attentions):
+ hidden_states = attn(hidden_states)
+ hidden_states = resnet(
+ hidden_states,
+ image_only_indicator=image_only_indicator,
+ )
+
+ return hidden_states
+
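+# `image_only_indicator` is assumed to be a (batch_size, num_frames) tensor, as
+# expected by diffusers' SpatioTemporalResBlock; it flags frames that should be
+# treated as plain images when the block blends its spatial and temporal paths.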
+
+class UpBlockTemporalDecoder(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ num_layers: int = 1,
+ add_upsample: bool = True,
+ ):
+ super().__init__()
+ resnets = []
+ for i in range(num_layers):
+ input_channels = in_channels if i == 0 else out_channels
+
+ resnets.append(
+ SpatioTemporalResBlock(
+ in_channels=input_channels,
+ out_channels=out_channels,
+ temb_channels=None,
+ eps=1e-6,
+ temporal_eps=1e-5,
+ merge_factor=0.0,
+ merge_strategy="learned",
+ switch_spatial_to_temporal_mix=True,
+ )
+ )
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ image_only_indicator: torch.FloatTensor,
+ ) -> torch.FloatTensor:
+ for resnet in self.resnets:
+ hidden_states = resnet(
+ hidden_states,
+ image_only_indicator=image_only_indicator,
+ )
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class UNetMidBlockSpatioTemporal(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ num_layers: int = 1,
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+ num_attention_heads: int = 1,
+ cross_attention_dim: int = 1280,
+ ):
+ super().__init__()
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+
+ # support for variable transformer layers per block
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
+ # there is always at least one resnet
+ resnets = [
+ SpatioTemporalResBlock(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=1e-5,
+ )
+ ]
+ attentions = []
+
+ for i in range(num_layers):
+ attentions.append(
+ TransformerSpatioTemporalModel(
+ num_attention_heads,
+ in_channels // num_attention_heads,
+ in_channels=in_channels,
+ num_layers=transformer_layers_per_block[i],
+ cross_attention_dim=cross_attention_dim,
+ )
+ )
+
+ resnets.append(
+ SpatioTemporalResBlock(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=1e-5,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ image_only_indicator: Optional[torch.Tensor] = None,
+ ) -> torch.FloatTensor:
+ hidden_states = self.resnets[0](
+ hidden_states,
+ temb,
+ image_only_indicator=image_only_indicator,
+ )
+
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ if self.training and self.gradient_checkpointing: # TODO
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ return_dict=False,
+ )[0]
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ image_only_indicator,
+ **ckpt_kwargs,
+ )
+ else:
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ return_dict=False,
+ )[0]
+ hidden_states = resnet(
+ hidden_states,
+ temb,
+ image_only_indicator=image_only_indicator,
+ )
+
+ return hidden_states
+
+
+class DownBlockSpatioTemporal(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ num_layers: int = 1,
+ add_downsample: bool = True,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ SpatioTemporalResBlock(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=1e-5,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels,
+ use_conv=True,
+ out_channels=out_channels,
+ name="op",
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ image_only_indicator: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ output_states = ()
+ for resnet in self.resnets:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ image_only_indicator,
+ use_reentrant=False,
+ )
+ else:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ image_only_indicator,
+ )
+ else:
+ hidden_states = resnet(
+ hidden_states,
+ temb,
+ image_only_indicator=image_only_indicator,
+ )
+
+ output_states = output_states + (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class CrossAttnDownBlockSpatioTemporal(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ num_layers: int = 1,
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+ num_attention_heads: int = 1,
+ cross_attention_dim: int = 1280,
+ add_downsample: bool = True,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ SpatioTemporalResBlock(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=1e-6,
+ )
+ )
+ attentions.append(
+ TransformerSpatioTemporalModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=transformer_layers_per_block[i],
+ cross_attention_dim=cross_attention_dim,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample2D(
+ out_channels,
+ use_conv=True,
+ out_channels=out_channels,
+ padding=1,
+ name="op",
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ image_only_indicator: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+ output_states = ()
+
+ blocks = list(zip(self.resnets, self.attentions))
+ for resnet, attn in blocks:
+ if self.training and self.gradient_checkpointing: # TODO
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ image_only_indicator,
+ **ckpt_kwargs,
+ )
+
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ return_dict=False,
+ )[0]
+ else:
+ hidden_states = resnet(
+ hidden_states,
+ temb,
+ image_only_indicator=image_only_indicator,
+ )
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ return_dict=False,
+ )[0]
+
+ output_states = output_states + (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states = output_states + (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class UpBlockSpatioTemporal(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ resolution_idx: Optional[int] = None,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ add_upsample: bool = True,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ SpatioTemporalResBlock(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ image_only_indicator: Optional[torch.Tensor] = None,
+ ) -> torch.FloatTensor:
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ if is_torch_version(">=", "1.11.0"):
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ image_only_indicator,
+ use_reentrant=False,
+ )
+ else:
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ image_only_indicator,
+ )
+ else:
+ hidden_states = resnet(
+ hidden_states,
+ temb,
+ image_only_indicator=image_only_indicator,
+ )
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
+
+
+class CrossAttnUpBlockSpatioTemporal(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ resolution_idx: Optional[int] = None,
+ num_layers: int = 1,
+ transformer_layers_per_block: Union[int, Tuple[int]] = 1,
+ resnet_eps: float = 1e-6,
+ num_attention_heads: int = 1,
+ cross_attention_dim: int = 1280,
+ add_upsample: bool = True,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.has_cross_attention = True
+ self.num_attention_heads = num_attention_heads
+
+ if isinstance(transformer_layers_per_block, int):
+ transformer_layers_per_block = [transformer_layers_per_block] * num_layers
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ SpatioTemporalResBlock(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ )
+ )
+ attentions.append(
+ TransformerSpatioTemporalModel(
+ num_attention_heads,
+ out_channels // num_attention_heads,
+ in_channels=out_channels,
+ num_layers=transformer_layers_per_block[i],
+ cross_attention_dim=cross_attention_dim,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+ self.resolution_idx = resolution_idx
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+ temb: Optional[torch.FloatTensor] = None,
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
+ image_only_indicator: Optional[torch.Tensor] = None,
+ ) -> torch.FloatTensor:
+ for resnet, attn in zip(self.resnets, self.attentions):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing: # TODO
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(resnet),
+ hidden_states,
+ temb,
+ image_only_indicator,
+ **ckpt_kwargs,
+ )
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ return_dict=False,
+ )[0]
+ else:
+ hidden_states = resnet(
+ hidden_states,
+ temb,
+ image_only_indicator=image_only_indicator,
+ )
+ hidden_states = attn(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ image_only_indicator=image_only_indicator,
+ return_dict=False,
+ )[0]
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states)
+
+ return hidden_states
diff --git a/libs/unet_motion_model.py b/libs/unet_motion_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8cddd585169350fdea85f243ac395a985ca0904
--- /dev/null
+++ b/libs/unet_motion_model.py
@@ -0,0 +1,975 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from typing import Any, Dict, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.loaders import UNet2DConditionLoadersMixin
+from diffusers.utils import logging, deprecate
+from diffusers.models.attention_processor import (
+ ADDED_KV_ATTENTION_PROCESSORS,
+ CROSS_ATTENTION_PROCESSORS,
+ AttentionProcessor,
+ AttnAddedKVProcessor,
+ AttnProcessor,
+)
+# from diffusers.models.controlnet import ControlNetConditioningEmbedding
+from diffusers.models.embeddings import TimestepEmbedding, Timesteps
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.transformers.transformer_temporal import TransformerTemporalModel
+from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2DCrossAttn
+from .unet_2d_condition import UNet2DConditionModel
+from .unet_3d_blocks import (
+ CrossAttnDownBlockMotion,
+ CrossAttnUpBlockMotion,
+ DownBlockMotion,
+ UNetMidBlockCrossAttnMotion,
+ UpBlockMotion,
+ get_down_block,
+ get_up_block,
+)
+from diffusers.models.unets.unet_3d_condition import UNet3DConditionOutput
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+class MotionModules(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ layers_per_block: int = 2,
+ num_attention_heads: int = 8,
+ attention_bias: bool = False,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ norm_num_groups: int = 32,
+ max_seq_length: int = 32,
+ ):
+ super().__init__()
+ self.motion_modules = nn.ModuleList([])
+
+ for i in range(layers_per_block):
+ self.motion_modules.append(
+ TransformerTemporalModel(
+ in_channels=in_channels,
+ norm_num_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ attention_bias=attention_bias,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=in_channels // num_attention_heads,
+ positional_embeddings="sinusoidal",
+ num_positional_embeddings=max_seq_length,
+ )
+ )
+
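+# Sizing sketch (derived from the constructor above, for illustration): with
+# in_channels=320, num_attention_heads=8 and layers_per_block=2, a MotionModules
+# container holds two TransformerTemporalModel layers, each using
+# attention_head_dim = 320 // 8 = 40 and sinusoidal positional embeddings capped
+# at `max_seq_length` positions.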
+
+class MotionAdapter(ModelMixin, ConfigMixin):
+ @register_to_config
+ def __init__(
+ self,
+ block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
+ motion_layers_per_block: int = 2,
+ motion_mid_block_layers_per_block: int = 1,
+ motion_num_attention_heads: int = 8,
+ motion_norm_num_groups: int = 32,
+ motion_max_seq_length: int = 32,
+ use_motion_mid_block: bool = True,
+ ):
+ """Container to store AnimateDiff Motion Modules
+
+ Args:
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
+ The tuple of output channels for each UNet block.
+ motion_layers_per_block (`int`, *optional*, defaults to 2):
+ The number of motion layers per UNet block.
+ motion_mid_block_layers_per_block (`int`, *optional*, defaults to 1):
+ The number of motion layers in the middle UNet block.
+ motion_num_attention_heads (`int`, *optional*, defaults to 8):
+ The number of heads to use in each attention layer of the motion module.
+ motion_norm_num_groups (`int`, *optional*, defaults to 32):
+ The number of groups to use in each group normalization layer of the motion module.
+ motion_max_seq_length (`int`, *optional*, defaults to 32):
+ The maximum sequence length to use in the motion module.
+ use_motion_mid_block (`bool`, *optional*, defaults to True):
+ Whether to use a motion module in the middle of the UNet.
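+
+        Example (an illustrative sketch; `unet2d` stands for an already-loaded
+        `UNet2DConditionModel` and is not created here):
+
+            adapter = MotionAdapter(block_out_channels=(320, 640, 1280, 1280))
+            unet = UNetMotionModel.from_unet2d(unet2d, motion_adapter=adapter)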
+ """
+
+ super().__init__()
+ down_blocks = []
+ up_blocks = []
+
+ for i, channel in enumerate(block_out_channels):
+ output_channel = block_out_channels[i]
+ down_blocks.append(
+ MotionModules(
+ in_channels=output_channel,
+ norm_num_groups=motion_norm_num_groups,
+ cross_attention_dim=None,
+ activation_fn="geglu",
+ attention_bias=False,
+ num_attention_heads=motion_num_attention_heads,
+ max_seq_length=motion_max_seq_length,
+ layers_per_block=motion_layers_per_block,
+ )
+ )
+
+ if use_motion_mid_block:
+ self.mid_block = MotionModules(
+ in_channels=block_out_channels[-1],
+ norm_num_groups=motion_norm_num_groups,
+ cross_attention_dim=None,
+ activation_fn="geglu",
+ attention_bias=False,
+ num_attention_heads=motion_num_attention_heads,
+ layers_per_block=motion_mid_block_layers_per_block,
+ max_seq_length=motion_max_seq_length,
+ )
+ else:
+ self.mid_block = None
+
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ output_channel = reversed_block_out_channels[0]
+ for i, channel in enumerate(reversed_block_out_channels):
+ output_channel = reversed_block_out_channels[i]
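+            # Up blocks carry one extra motion layer (`layers_per_block + 1`) so
+            # they line up with the UNet up blocks, which are built with one more
+            # resnet than the down blocks (see the `get_up_block(...,
+            # num_layers=layers_per_block + 1)` call in UNetMotionModel below).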
+ up_blocks.append(
+ MotionModules(
+ in_channels=output_channel,
+ norm_num_groups=motion_norm_num_groups,
+ cross_attention_dim=None,
+ activation_fn="geglu",
+ attention_bias=False,
+ num_attention_heads=motion_num_attention_heads,
+ max_seq_length=motion_max_seq_length,
+ layers_per_block=motion_layers_per_block + 1,
+ )
+ )
+
+ self.down_blocks = nn.ModuleList(down_blocks)
+ self.up_blocks = nn.ModuleList(up_blocks)
+
+ def forward(self, sample):
+ pass
+
+
+class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
+ r"""
+ A modified conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a
+    sample-shaped output.
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
+ for all models (such as downloading or saving).
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None,
+ in_channels: int = 4,
+ conditioning_channels: int = 3,
+ out_channels: int = 4,
+ down_block_types: Tuple[str, ...] = (
+ "CrossAttnDownBlockMotion",
+ "CrossAttnDownBlockMotion",
+ "CrossAttnDownBlockMotion",
+ "DownBlockMotion",
+ ),
+ mid_block_type: Optional[str] = "UNetMidBlockCrossAttnMotion",
+ up_block_types: Tuple[str, ...] = (
+ "UpBlockMotion",
+ "CrossAttnUpBlockMotion",
+ "CrossAttnUpBlockMotion",
+ "CrossAttnUpBlockMotion",
+ ),
+ block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280),
+ layers_per_block: int = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: int = 1280,
+ use_linear_projection: bool = False,
+ num_attention_heads: Union[int, Tuple[int, ...]] = 8,
+ motion_max_seq_length: int = 32,
+ motion_num_attention_heads: int = 8,
+        use_motion_mid_block: bool = True,
+ encoder_hid_dim: Optional[int] = None,
+ encoder_hid_dim_type: Optional[str] = None,
+ conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256),
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+
+ # Check inputs
+ if len(down_block_types) != len(up_block_types):
+ raise ValueError(
+ f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
+ )
+
+ if len(block_out_channels) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
+ )
+
+ if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
+ raise ValueError(
+ f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
+ )
+
+ # input
+ conv_in_kernel = 3
+ conv_out_kernel = 3
+ conv_in_padding = (conv_in_kernel - 1) // 2
+ self.conv_in = nn.Conv2d(
+ in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding
+ )
+
+ # time
+ time_embed_dim = block_out_channels[0] * 4
+ self.time_proj = Timesteps(block_out_channels[0], True, 0)
+ timestep_input_dim = block_out_channels[0]
+
+ self.time_embedding = TimestepEmbedding(
+ timestep_input_dim,
+ time_embed_dim,
+ act_fn=act_fn,
+ )
+
+ if encoder_hid_dim_type is None:
+ self.encoder_hid_proj = None
+
+ # control net conditioning embedding
+ # self.controlnet_cond_embedding = ControlNetConditioningEmbedding(
+ # conditioning_embedding_channels=block_out_channels[0],
+ # block_out_channels=conditioning_embedding_out_channels,
+ # conditioning_channels=conditioning_channels,
+ # )
+
+ # class embedding
+ self.down_blocks = nn.ModuleList([])
+ self.up_blocks = nn.ModuleList([])
+
+ if isinstance(num_attention_heads, int):
+ num_attention_heads = (num_attention_heads,) * len(down_block_types)
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads[i],
+ downsample_padding=downsample_padding,
+ use_linear_projection=use_linear_projection,
+ dual_cross_attention=False,
+ temporal_num_attention_heads=motion_num_attention_heads,
+ temporal_max_seq_length=motion_max_seq_length,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ if use_motion_mid_block:
+ self.mid_block = UNetMidBlockCrossAttnMotion(
+ in_channels=block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads[-1],
+ resnet_groups=norm_num_groups,
+ dual_cross_attention=False,
+ temporal_num_attention_heads=motion_num_attention_heads,
+ temporal_max_seq_length=motion_max_seq_length,
+ )
+
+ else:
+ self.mid_block = UNetMidBlock2DCrossAttn(
+ in_channels=block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=num_attention_heads[-1],
+ resnet_groups=norm_num_groups,
+ dual_cross_attention=False,
+ )
+
+ # count how many layers upsample the images
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_num_attention_heads = list(reversed(num_attention_heads))
+
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=layers_per_block + 1,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ num_attention_heads=reversed_num_attention_heads[i],
+ dual_cross_attention=False,
+ resolution_idx=i,
+ use_linear_projection=use_linear_projection,
+ temporal_num_attention_heads=motion_num_attention_heads,
+ temporal_max_seq_length=motion_max_seq_length,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ if norm_num_groups is not None:
+ self.conv_norm_out = nn.GroupNorm(
+ num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
+ )
+ self.conv_act = nn.SiLU()
+ else:
+ self.conv_norm_out = None
+ self.conv_act = None
+
+ conv_out_padding = (conv_out_kernel - 1) // 2
+ self.conv_out = nn.Conv2d(
+ block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
+ )
+
+ @classmethod
+ def from_unet2d(
+ cls,
+ unet: UNet2DConditionModel,
+ motion_adapter: Optional[MotionAdapter] = None,
+ load_weights: bool = True,
+ ):
+ has_motion_adapter = motion_adapter is not None
+
+ # based on https://github.com/guoyww/AnimateDiff/blob/895f3220c06318ea0760131ec70408b466c49333/animatediff/models/unet.py#L459
+ config = unet.config
+ config["_class_name"] = cls.__name__
+
+ down_blocks = []
+ for down_blocks_type in config["down_block_types"]:
+ if "CrossAttn" in down_blocks_type:
+ down_blocks.append("CrossAttnDownBlockMotion")
+ else:
+ down_blocks.append("DownBlockMotion")
+ config["down_block_types"] = down_blocks
+
+ up_blocks = []
+ for up_block_type in config["up_block_types"]:
+ if "CrossAttn" in up_block_type:
+ up_blocks.append("CrossAttnUpBlockMotion")
+ else:
+ up_blocks.append("UpBlockMotion")
+
+ config["up_block_types"] = up_blocks
+
+ if has_motion_adapter:
+ config["motion_num_attention_heads"] = motion_adapter.config["motion_num_attention_heads"]
+ config["motion_max_seq_length"] = motion_adapter.config["motion_max_seq_length"]
+ config["use_motion_mid_block"] = motion_adapter.config["use_motion_mid_block"]
+
+ # Need this for backwards compatibility with UNet2DConditionModel checkpoints
+ if not config.get("num_attention_heads"):
+ config["num_attention_heads"] = config["attention_head_dim"]
+
+ model = cls.from_config(config)
+
+ if not load_weights:
+ return model
+
+ model.conv_in.load_state_dict(unet.conv_in.state_dict())
+ model.time_proj.load_state_dict(unet.time_proj.state_dict())
+ model.time_embedding.load_state_dict(unet.time_embedding.state_dict())
+ # model.controlnet_cond_embedding.load_state_dict(unet.controlnet_cond_embedding.state_dict()) # pose guider
+
+ for i, down_block in enumerate(unet.down_blocks):
+ model.down_blocks[i].resnets.load_state_dict(down_block.resnets.state_dict())
+ if hasattr(model.down_blocks[i], "attentions"):
+ model.down_blocks[i].attentions.load_state_dict(down_block.attentions.state_dict())
+ if model.down_blocks[i].downsamplers:
+ model.down_blocks[i].downsamplers.load_state_dict(down_block.downsamplers.state_dict())
+
+ for i, up_block in enumerate(unet.up_blocks):
+ model.up_blocks[i].resnets.load_state_dict(up_block.resnets.state_dict())
+ if hasattr(model.up_blocks[i], "attentions"):
+ model.up_blocks[i].attentions.load_state_dict(up_block.attentions.state_dict())
+ if model.up_blocks[i].upsamplers:
+ model.up_blocks[i].upsamplers.load_state_dict(up_block.upsamplers.state_dict())
+
+ model.mid_block.resnets.load_state_dict(unet.mid_block.resnets.state_dict())
+ model.mid_block.attentions.load_state_dict(unet.mid_block.attentions.state_dict())
+
+ if unet.conv_norm_out is not None:
+ model.conv_norm_out.load_state_dict(unet.conv_norm_out.state_dict())
+ if unet.conv_act is not None:
+ model.conv_act.load_state_dict(unet.conv_act.state_dict())
+ model.conv_out.load_state_dict(unet.conv_out.state_dict())
+
+ if has_motion_adapter:
+ model.load_motion_modules(motion_adapter)
+
+ # ensure that the Motion UNet is the same dtype as the UNet2DConditionModel
+ model.to(unet.dtype)
+
+ return model
+
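+ # Example use of `from_unet2d` above (illustrative sketch; the repo ids below are placeholders):
+ # unet = UNet2DConditionModel.from_pretrained("path/or/repo-id", subfolder="unet")
+ # adapter = MotionAdapter.from_pretrained("path/to/motion-adapter")
+ # motion_unet = UNetMotionModel.from_unet2d(unet, motion_adapter=adapter, load_weights=True)
+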
+ def freeze_unet2d_params(self) -> None:
+ """Freeze the weights of just the UNet2DConditionModel, and leave the motion modules
+ unfrozen for fine tuning.
+ """
+ # Freeze everything
+ for param in self.parameters():
+ param.requires_grad = False
+
+ # Unfreeze Motion Modules
+ for down_block in self.down_blocks:
+ motion_modules = down_block.motion_modules
+ for param in motion_modules.parameters():
+ param.requires_grad = True
+
+ for up_block in self.up_blocks:
+ motion_modules = up_block.motion_modules
+ for param in motion_modules.parameters():
+ param.requires_grad = True
+
+ if hasattr(self.mid_block, "motion_modules"):
+ motion_modules = self.mid_block.motion_modules
+ for param in motion_modules.parameters():
+ param.requires_grad = True
+
+ def load_motion_modules(self, motion_adapter: Optional[MotionAdapter]) -> None:
+ for i, down_block in enumerate(motion_adapter.down_blocks):
+ self.down_blocks[i].motion_modules.load_state_dict(down_block.motion_modules.state_dict())
+ for i, up_block in enumerate(motion_adapter.up_blocks):
+ self.up_blocks[i].motion_modules.load_state_dict(up_block.motion_modules.state_dict())
+
+ # to support older motion modules that don't have a mid_block
+ if hasattr(self.mid_block, "motion_modules"):
+ self.mid_block.motion_modules.load_state_dict(motion_adapter.mid_block.motion_modules.state_dict())
+
+ def save_motion_modules(
+ self,
+ save_directory: str,
+ is_main_process: bool = True,
+ safe_serialization: bool = True,
+ variant: Optional[str] = None,
+ push_to_hub: bool = False,
+ **kwargs,
+ ) -> None:
+ state_dict = self.state_dict()
+
+ # Extract all motion modules
+ motion_state_dict = {}
+ for k, v in state_dict.items():
+ if "motion_modules" in k:
+ motion_state_dict[k] = v
+
+ adapter = MotionAdapter(
+ block_out_channels=self.config["block_out_channels"],
+ motion_layers_per_block=self.config["layers_per_block"],
+ motion_norm_num_groups=self.config["norm_num_groups"],
+ motion_num_attention_heads=self.config["motion_num_attention_heads"],
+ motion_max_seq_length=self.config["motion_max_seq_length"],
+ use_motion_mid_block=self.config["use_motion_mid_block"],
+ )
+ adapter.load_state_dict(motion_state_dict)
+ adapter.save_pretrained(
+ save_directory=save_directory,
+ is_main_process=is_main_process,
+ safe_serialization=safe_serialization,
+ variant=variant,
+ push_to_hub=push_to_hub,
+ **kwargs,
+ )
+
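+ # Fine-tuning sketch for the methods above (an assumed workflow, not enforced by this class):
+ # model.freeze_unet2d_params()
+ # ... training loop that updates only parameters with requires_grad=True ...
+ # model.save_motion_modules("motion_modules/", safe_serialization=True)
+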
+ @property
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
+ indexed by their weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(
+ self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False
+ ):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor, _remove_lora=_remove_lora)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
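+ # Example (sketch): reset every attention layer to the plain AttnProcessor used elsewhere in this
+ # module, either globally or per layer name.
+ # model.set_attn_processor(AttnProcessor())
+ # model.set_attn_processor({name: AttnProcessor() for name in model.attn_processors})
+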
+ # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
+ def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
+ """
+ Sets the attention processor to use [feed forward
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
+
+ Parameters:
+ chunk_size (`int`, *optional*):
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
+ over each tensor of dim=`dim`.
+ dim (`int`, *optional*, defaults to `0`):
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
+ or dim=1 (sequence length).
+ """
+ if dim not in [0, 1]:
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
+
+ # By default chunk size is 1
+ chunk_size = chunk_size or 1
+
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+ if hasattr(module, "set_chunk_feed_forward"):
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+ for child in module.children():
+ fn_recursive_feed_forward(child, chunk_size, dim)
+
+ for module in self.children():
+ fn_recursive_feed_forward(module, chunk_size, dim)
+
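+ # Example (assumption: chunk along the sequence dimension to trade a little speed for memory):
+ # model.enable_forward_chunking(chunk_size=1, dim=1)
+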
+ # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking
+ def disable_forward_chunking(self) -> None:
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+ if hasattr(module, "set_chunk_feed_forward"):
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+ for child in module.children():
+ fn_recursive_feed_forward(child, chunk_size, dim)
+
+ for module in self.children():
+ fn_recursive_feed_forward(module, None, 0)
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+ def set_default_attn_processor(self) -> None:
+ """
+ Disables custom attention processors and sets the default attention implementation.
+ """
+ if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnAddedKVProcessor()
+ elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
+ processor = AttnProcessor()
+ else:
+ raise ValueError(
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
+ )
+
+ self.set_attn_processor(processor, _remove_lora=True)
+
+ def _set_gradient_checkpointing(self, module, value: bool = False) -> None:
+ if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)):
+ module.gradient_checkpointing = value
+
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.enable_freeu
+ def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> None:
+ r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
+
+ The suffixes after the scaling factors represent the stage blocks where they are being applied.
+
+ Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that
+ are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL.
+
+ Args:
+ s1 (`float`):
+ Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ s2 (`float`):
+ Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to
+ mitigate the "oversmoothing effect" in the enhanced denoising process.
+ b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features.
+ b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features.
+ """
+ for i, upsample_block in enumerate(self.up_blocks):
+ setattr(upsample_block, "s1", s1)
+ setattr(upsample_block, "s2", s2)
+ setattr(upsample_block, "b1", b1)
+ setattr(upsample_block, "b2", b2)
+
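+ # Example (the values below are the FreeU authors' suggestions for SD1.x-style backbones;
+ # treat them as a starting point rather than a guarantee for this model):
+ # model.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)
+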
+ # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.disable_freeu
+ def disable_freeu(self) -> None:
+ """Disables the FreeU mechanism."""
+ freeu_keys = {"s1", "s2", "b1", "b2"}
+ for i, upsample_block in enumerate(self.up_blocks):
+ for k in freeu_keys:
+ if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
+ setattr(upsample_block, k, None)
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor,
+ # controlnet_cond: torch.FloatTensor,
+ timestep_cond: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ num_frames: int = 24,
+ down_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
+ mid_block_add_sample: Optional[Tuple[torch.Tensor]] = None,
+ up_block_add_samples: Optional[Tuple[torch.Tensor]] = None,
+ ) -> Union[UNet3DConditionOutput, Tuple[torch.Tensor]]:
+ r"""
+ The [`UNetMotionModel`] forward method.
+
+ Args:
+ sample (`torch.FloatTensor`):
+ The noisy input tensor with the following shape `(batch * num_frames, channel, height, width)`.
+ timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+ encoder_hidden_states (`torch.FloatTensor`):
+ The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
+ timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`):
+ Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed
+ through the `self.time_embedding` layer to obtain the timestep embeddings.
+ attention_mask (`torch.Tensor`, *optional*, defaults to `None`):
+ An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask
+ is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large
+ negative values to the attention scores corresponding to "discard" tokens.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
+ A tuple of tensors that if specified are added to the residuals of down unet blocks.
+ mid_block_additional_residual: (`torch.Tensor`, *optional*):
+ A tensor that if specified is added to the residual of the middle unet block.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ [`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`:
+ If `return_dict` is True, an [`~models.unet_3d_condition.UNet3DConditionOutput`] is returned, otherwise
+ a `tuple` is returned where the first element is the sample tensor.
+ """
+ # By default samples have to be at least a multiple of the overall upsampling factor.
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
+ # on the fly if necessary.
+ default_overall_up_factor = 2**self.num_upsamplers
+
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+ forward_upsample_size = False
+ upsample_size = None
+
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+ logger.info("Forward upsample size to force interpolation output size.")
+ forward_upsample_size = True
+
+ # prepare attention_mask
+ if attention_mask is not None:
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # 1. time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0] // num_frames)
+
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=self.dtype)
+
+ emb = self.time_embedding(t_emb, timestep_cond)
+ emb = emb.repeat_interleave(repeats=num_frames, dim=0)
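+ # `sample` is laid out as (batch * num_frames, C, H, W); the per-clip time embedding is
+ # repeated per frame so the two batch dimensions line up.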
+
+ if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
+ if "image_embeds" not in added_cond_kwargs:
+ raise ValueError(
+ f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
+ )
+ image_embeds = added_cond_kwargs.get("image_embeds")
+ image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
+ encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
+
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
+
+ # 2. pre-process
+ # sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])
+ # N*T C H W
+ sample = self.conv_in(sample)
+ # controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
+ # sample += controlnet_cond
+
+ # 3. down
+ is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
+ # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
+ is_adapter = down_intrablock_additional_residuals is not None
+ # maintain backward compatibility for legacy usage, where
+ # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
+ # but can only use one or the other
+ is_brushnet = down_block_add_samples is not None and mid_block_add_sample is not None and up_block_add_samples is not None
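+ # BrushNet-style injection (as wired below): the *_add_samples lists are consumed in order
+ # via pop(0) by the down, mid and up blocks respectively.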
+ if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
+ deprecate(
+ "T2I should not use down_block_additional_residuals",
+ "1.3.0",
+ "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
+ and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
+ for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. ",
+ standard_warn=False,
+ )
+ down_intrablock_additional_residuals = down_block_additional_residuals
+ is_adapter = True
+
+ down_block_res_samples = (sample,)
+ if is_brushnet:
+ sample = sample + down_block_add_samples.pop(0)
+
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ # For t2i-adapter CrossAttnDownBlock2D
+ additional_residuals = {}
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
+ additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0)
+ if is_brushnet and len(down_block_add_samples) > 0:
+ additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0)
+ for _ in range(len(downsample_block.resnets) + (downsample_block.downsamplers is not None))]
+
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ num_frames=num_frames,
+ cross_attention_kwargs=cross_attention_kwargs,
+ **additional_residuals,
+ )
+ else:
+ additional_residuals = {}
+ if is_brushnet and len(down_block_add_samples) > 0:
+ additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0)
+ for _ in range(len(downsample_block.resnets) + (downsample_block.downsamplers is not None))]
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames, **additional_residuals)
+ if is_adapter and len(down_intrablock_additional_residuals) > 0:
+ sample += down_intrablock_additional_residuals.pop(0)
+
+ down_block_res_samples += res_samples
+
+ if is_controlnet:
+ new_down_block_res_samples = ()
+
+ for down_block_res_sample, down_block_additional_residual in zip(
+ down_block_res_samples, down_block_additional_residuals
+ ):
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
+ new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
+
+ down_block_res_samples = new_down_block_res_samples
+
+ if down_block_additional_residuals is not None:
+ new_down_block_res_samples = ()
+
+ for down_block_res_sample, down_block_additional_residual in zip(
+ down_block_res_samples, down_block_additional_residuals
+ ):
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
+ new_down_block_res_samples += (down_block_res_sample,)
+
+ down_block_res_samples = new_down_block_res_samples
+
+ # 4. mid
+ if self.mid_block is not None:
+ # To support older versions of motion modules that don't have a mid_block
+ if hasattr(self.mid_block, "motion_modules"):
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ num_frames=num_frames,
+ cross_attention_kwargs=cross_attention_kwargs,
+ mid_block_add_sample=mid_block_add_sample,
+ )
+ else:
+ sample = self.mid_block(
+ sample,
+ emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ cross_attention_kwargs=cross_attention_kwargs,
+ mid_block_add_sample=mid_block_add_sample,
+ )
+
+ if is_controlnet:
+ sample = sample + mid_block_additional_residual
+
+ # if is_brushnet:
+ # sample = sample + mid_block_add_sample
+
+ if mid_block_additional_residual is not None:
+ sample = sample + mid_block_additional_residual
+
+ # 5. up
+ for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ additional_residuals = {}
+ if is_brushnet and len(up_block_add_samples) > 0:
+ additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0)
+ for _ in range(len(upsample_block.resnets) + (upsample_block.upsamplers is not None))]
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ num_frames=num_frames,
+ cross_attention_kwargs=cross_attention_kwargs,
+ **additional_residuals,
+ )
+ else:
+ additional_residuals = {}
+ if is_brushnet and len(up_block_add_samples) > 0:
+ additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0)
+ for _ in range(len(upsample_block.resnets) + (upsample_block.upsamplers is not None))]
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ upsample_size=upsample_size,
+ num_frames=num_frames,
+ **additional_residuals,
+ )
+
+ # 6. post-process
+ if self.conv_norm_out:
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+
+ sample = self.conv_out(sample)
+
+ # reshape to (batch, num_frames, channel, height, width)
+ # sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:])
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet3DConditionOutput(sample=sample)
diff --git a/propainter/RAFT/__init__.py b/propainter/RAFT/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7179ea3ce4ad81425c619772d4bc47bc7ceea3a
--- /dev/null
+++ b/propainter/RAFT/__init__.py
@@ -0,0 +1,2 @@
+# from .demo import RAFT_infer
+from .raft import RAFT
diff --git a/propainter/RAFT/corr.py b/propainter/RAFT/corr.py
new file mode 100644
index 0000000000000000000000000000000000000000..449dbd963b8303eda242a65063ca857b95475721
--- /dev/null
+++ b/propainter/RAFT/corr.py
@@ -0,0 +1,111 @@
+import torch
+import torch.nn.functional as F
+from .utils.utils import bilinear_sampler, coords_grid
+
+try:
+ import alt_cuda_corr
+except ImportError:
+ # alt_cuda_corr is not compiled
+ pass
+
+
+class CorrBlock:
+ def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
+ self.num_levels = num_levels
+ self.radius = radius
+ self.corr_pyramid = []
+
+ # all pairs correlation
+ corr = CorrBlock.corr(fmap1, fmap2)
+
+ batch, h1, w1, dim, h2, w2 = corr.shape
+ corr = corr.reshape(batch*h1*w1, dim, h2, w2)
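+ # each source pixel becomes its own batch entry, so the (h2, w2) target grid can be
+ # average-pooled level by level to build the correlation pyramid below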
+
+ self.corr_pyramid.append(corr)
+ for i in range(self.num_levels-1):
+ corr = F.avg_pool2d(corr, 2, stride=2)
+ self.corr_pyramid.append(corr)
+
+ def __call__(self, coords):
+ r = self.radius
+ coords = coords.permute(0, 2, 3, 1)
+ batch, h1, w1, _ = coords.shape
+
+ out_pyramid = []
+ for i in range(self.num_levels):
+ corr = self.corr_pyramid[i]
+ dx = torch.linspace(-r, r, 2*r+1)
+ dy = torch.linspace(-r, r, 2*r+1)
+ delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device)
+
+ centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i
+ delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2)
+ coords_lvl = centroid_lvl + delta_lvl
+
+ corr = bilinear_sampler(corr, coords_lvl)
+ corr = corr.view(batch, h1, w1, -1)
+ out_pyramid.append(corr)
+
+ out = torch.cat(out_pyramid, dim=-1)
+ return out.permute(0, 3, 1, 2).contiguous().float()
+
+ @staticmethod
+ def corr(fmap1, fmap2):
+ batch, dim, ht, wd = fmap1.shape
+ fmap1 = fmap1.view(batch, dim, ht*wd)
+ fmap2 = fmap2.view(batch, dim, ht*wd)
+
+ corr = torch.matmul(fmap1.transpose(1,2), fmap2)
+ corr = corr.view(batch, ht, wd, 1, ht, wd)
+ return corr / torch.sqrt(torch.tensor(dim).float())
+
+
+class CorrLayer(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, fmap1, fmap2, coords, r):
+ fmap1 = fmap1.contiguous()
+ fmap2 = fmap2.contiguous()
+ coords = coords.contiguous()
+ ctx.save_for_backward(fmap1, fmap2, coords)
+ ctx.r = r
+ corr, = correlation_cudaz.forward(fmap1, fmap2, coords, ctx.r)
+ return corr
+
+ @staticmethod
+ def backward(ctx, grad_corr):
+ fmap1, fmap2, coords = ctx.saved_tensors
+ grad_corr = grad_corr.contiguous()
+ fmap1_grad, fmap2_grad, coords_grad = \
+ correlation_cudaz.backward(fmap1, fmap2, coords, grad_corr, ctx.r)
+ return fmap1_grad, fmap2_grad, coords_grad, None
+
+
+class AlternateCorrBlock:
+ def __init__(self, fmap1, fmap2, num_levels=4, radius=4):
+ self.num_levels = num_levels
+ self.radius = radius
+
+ self.pyramid = [(fmap1, fmap2)]
+ for i in range(self.num_levels):
+ fmap1 = F.avg_pool2d(fmap1, 2, stride=2)
+ fmap2 = F.avg_pool2d(fmap2, 2, stride=2)
+ self.pyramid.append((fmap1, fmap2))
+
+ def __call__(self, coords):
+
+ coords = coords.permute(0, 2, 3, 1)
+ B, H, W, _ = coords.shape
+
+ corr_list = []
+ for i in range(self.num_levels):
+ r = self.radius
+ fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1)
+ fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1)
+
+ coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous()
+ corr, = alt_cuda_corr.forward(fmap1_i, fmap2_i, coords_i, r)
+ corr_list.append(corr.squeeze(1))
+
+ corr = torch.stack(corr_list, dim=1)
+ corr = corr.reshape(B, -1, H, W)
+ return corr / 16.0
diff --git a/propainter/RAFT/datasets.py b/propainter/RAFT/datasets.py
new file mode 100644
index 0000000000000000000000000000000000000000..3411fdacfb900024005e8997d07c600e963a95ca
--- /dev/null
+++ b/propainter/RAFT/datasets.py
@@ -0,0 +1,235 @@
+# Data loading based on https://github.com/NVIDIA/flownet2-pytorch
+
+import numpy as np
+import torch
+import torch.utils.data as data
+import torch.nn.functional as F
+
+import os
+import math
+import random
+from glob import glob
+import os.path as osp
+
+from .utils import frame_utils
+from .utils.augmentor import FlowAugmentor, SparseFlowAugmentor
+
+
+class FlowDataset(data.Dataset):
+ def __init__(self, aug_params=None, sparse=False):
+ self.augmentor = None
+ self.sparse = sparse
+ if aug_params is not None:
+ if sparse:
+ self.augmentor = SparseFlowAugmentor(**aug_params)
+ else:
+ self.augmentor = FlowAugmentor(**aug_params)
+
+ self.is_test = False
+ self.init_seed = False
+ self.flow_list = []
+ self.image_list = []
+ self.extra_info = []
+
+ def __getitem__(self, index):
+
+ if self.is_test:
+ img1 = frame_utils.read_gen(self.image_list[index][0])
+ img2 = frame_utils.read_gen(self.image_list[index][1])
+ img1 = np.array(img1).astype(np.uint8)[..., :3]
+ img2 = np.array(img2).astype(np.uint8)[..., :3]
+ img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
+ img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
+ return img1, img2, self.extra_info[index]
+
+ if not self.init_seed:
+ worker_info = torch.utils.data.get_worker_info()
+ if worker_info is not None:
+ torch.manual_seed(worker_info.id)
+ np.random.seed(worker_info.id)
+ random.seed(worker_info.id)
+ self.init_seed = True
+
+ index = index % len(self.image_list)
+ valid = None
+ if self.sparse:
+ flow, valid = frame_utils.readFlowKITTI(self.flow_list[index])
+ else:
+ flow = frame_utils.read_gen(self.flow_list[index])
+
+ img1 = frame_utils.read_gen(self.image_list[index][0])
+ img2 = frame_utils.read_gen(self.image_list[index][1])
+
+ flow = np.array(flow).astype(np.float32)
+ img1 = np.array(img1).astype(np.uint8)
+ img2 = np.array(img2).astype(np.uint8)
+
+ # grayscale images
+ if len(img1.shape) == 2:
+ img1 = np.tile(img1[...,None], (1, 1, 3))
+ img2 = np.tile(img2[...,None], (1, 1, 3))
+ else:
+ img1 = img1[..., :3]
+ img2 = img2[..., :3]
+
+ if self.augmentor is not None:
+ if self.sparse:
+ img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid)
+ else:
+ img1, img2, flow = self.augmentor(img1, img2, flow)
+
+ img1 = torch.from_numpy(img1).permute(2, 0, 1).float()
+ img2 = torch.from_numpy(img2).permute(2, 0, 1).float()
+ flow = torch.from_numpy(flow).permute(2, 0, 1).float()
+
+ if valid is not None:
+ valid = torch.from_numpy(valid)
+ else:
+ valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000)
+
+ return img1, img2, flow, valid.float()
+
+
+ def __rmul__(self, v):
+ self.flow_list = v * self.flow_list
+ self.image_list = v * self.image_list
+ return self
+
+ def __len__(self):
+ return len(self.image_list)
+
+
+class MpiSintel(FlowDataset):
+ def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'):
+ super(MpiSintel, self).__init__(aug_params)
+ flow_root = osp.join(root, split, 'flow')
+ image_root = osp.join(root, split, dstype)
+
+ if split == 'test':
+ self.is_test = True
+
+ for scene in os.listdir(image_root):
+ image_list = sorted(glob(osp.join(image_root, scene, '*.png')))
+ for i in range(len(image_list)-1):
+ self.image_list += [ [image_list[i], image_list[i+1]] ]
+ self.extra_info += [ (scene, i) ] # scene and frame_id
+
+ if split != 'test':
+ self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo')))
+
+
+class FlyingChairs(FlowDataset):
+ def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'):
+ super(FlyingChairs, self).__init__(aug_params)
+
+ images = sorted(glob(osp.join(root, '*.ppm')))
+ flows = sorted(glob(osp.join(root, '*.flo')))
+ assert (len(images)//2 == len(flows))
+
+ split_list = np.loadtxt('chairs_split.txt', dtype=np.int32)
+ for i in range(len(flows)):
+ xid = split_list[i]
+ if (split=='training' and xid==1) or (split=='validation' and xid==2):
+ self.flow_list += [ flows[i] ]
+ self.image_list += [ [images[2*i], images[2*i+1]] ]
+
+
+class FlyingThings3D(FlowDataset):
+ def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'):
+ super(FlyingThings3D, self).__init__(aug_params)
+
+ for cam in ['left']:
+ for direction in ['into_future', 'into_past']:
+ image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*')))
+ image_dirs = sorted([osp.join(f, cam) for f in image_dirs])
+
+ flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*')))
+ flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs])
+
+ for idir, fdir in zip(image_dirs, flow_dirs):
+ images = sorted(glob(osp.join(idir, '*.png')) )
+ flows = sorted(glob(osp.join(fdir, '*.pfm')) )
+ for i in range(len(flows)-1):
+ if direction == 'into_future':
+ self.image_list += [ [images[i], images[i+1]] ]
+ self.flow_list += [ flows[i] ]
+ elif direction == 'into_past':
+ self.image_list += [ [images[i+1], images[i]] ]
+ self.flow_list += [ flows[i+1] ]
+
+
+class KITTI(FlowDataset):
+ def __init__(self, aug_params=None, split='training', root='datasets/KITTI'):
+ super(KITTI, self).__init__(aug_params, sparse=True)
+ if split == 'testing':
+ self.is_test = True
+
+ root = osp.join(root, split)
+ images1 = sorted(glob(osp.join(root, 'image_2/*_10.png')))
+ images2 = sorted(glob(osp.join(root, 'image_2/*_11.png')))
+
+ for img1, img2 in zip(images1, images2):
+ frame_id = img1.split('/')[-1]
+ self.extra_info += [ [frame_id] ]
+ self.image_list += [ [img1, img2] ]
+
+ if split == 'training':
+ self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png')))
+
+
+class HD1K(FlowDataset):
+ def __init__(self, aug_params=None, root='datasets/HD1k'):
+ super(HD1K, self).__init__(aug_params, sparse=True)
+
+ seq_ix = 0
+ while True:
+ flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix)))
+ images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix)))
+
+ if len(flows) == 0:
+ break
+
+ for i in range(len(flows)-1):
+ self.flow_list += [flows[i]]
+ self.image_list += [ [images[i], images[i+1]] ]
+
+ seq_ix += 1
+
+
+def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'):
+ """ Create the data loader for the corresponding trainign set """
+
+ if args.stage == 'chairs':
+ aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 'do_flip': True}
+ train_dataset = FlyingChairs(aug_params, split='training')
+
+ elif args.stage == 'things':
+ aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True}
+ clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass')
+ final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass')
+ train_dataset = clean_dataset + final_dataset
+
+ elif args.stage == 'sintel':
+ aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True}
+ things = FlyingThings3D(aug_params, dstype='frames_cleanpass')
+ sintel_clean = MpiSintel(aug_params, split='training', dstype='clean')
+ sintel_final = MpiSintel(aug_params, split='training', dstype='final')
+
+ if TRAIN_DS == 'C+T+K+S+H':
+ kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True})
+ hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True})
+ train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things
+
+ elif TRAIN_DS == 'C+T+K/S':
+ train_dataset = 100*sintel_clean + 100*sintel_final + things
+
+ elif args.stage == 'kitti':
+ aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False}
+ train_dataset = KITTI(aug_params, split='training')
+
+ train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size,
+ pin_memory=False, shuffle=True, num_workers=4, drop_last=True)
+
+ print('Training with %d image pairs' % len(train_dataset))
+ return train_loader
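+
+ # Example (sketch; the `args` fields are assumptions matching how they are read above):
+ # import argparse
+ # args = argparse.Namespace(stage='sintel', image_size=[368, 768], batch_size=6)
+ # loader = fetch_dataloader(args, TRAIN_DS='C+T+K+S+H')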
+
diff --git a/propainter/RAFT/demo.py b/propainter/RAFT/demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..096963bdbb36aed3df673f131d6e044d8c6f95ea
--- /dev/null
+++ b/propainter/RAFT/demo.py
@@ -0,0 +1,79 @@
+import sys
+import argparse
+import os
+import cv2
+import glob
+import numpy as np
+import torch
+from PIL import Image
+
+from .raft import RAFT
+from .utils import flow_viz
+from .utils.utils import InputPadder
+
+
+
+DEVICE = 'cuda'
+
+def load_image(imfile):
+ img = np.array(Image.open(imfile)).astype(np.uint8)
+ img = torch.from_numpy(img).permute(2, 0, 1).float()
+ return img
+
+
+def load_image_list(image_files):
+ images = []
+ for imfile in sorted(image_files):
+ images.append(load_image(imfile))
+
+ images = torch.stack(images, dim=0)
+ images = images.to(DEVICE)
+
+ padder = InputPadder(images.shape)
+ return padder.pad(images)[0]
+
+
+def viz(img, flo):
+ img = img[0].permute(1,2,0).cpu().numpy()
+ flo = flo[0].permute(1,2,0).cpu().numpy()
+
+ # map flow to rgb image
+ flo = flow_viz.flow_to_image(flo)
+ # img_flo = np.concatenate([img, flo], axis=0)
+ img_flo = flo
+
+ cv2.imwrite('/home/chengao/test/flow.png', img_flo[:, :, [2,1,0]])
+ # cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0)
+ # cv2.waitKey()
+
+
+def demo(args):
+ model = torch.nn.DataParallel(RAFT(args))
+ model.load_state_dict(torch.load(args.model))
+
+ model = model.module
+ model.to(DEVICE)
+ model.eval()
+
+ with torch.no_grad():
+ images = glob.glob(os.path.join(args.path, '*.png')) + \
+ glob.glob(os.path.join(args.path, '*.jpg'))
+
+ images = load_image_list(images)
+ for i in range(images.shape[0]-1):
+ image1 = images[i,None]
+ image2 = images[i+1,None]
+
+ flow_low, flow_up = model(image1, image2, iters=20, test_mode=True)
+ viz(image1, flow_up)
+
+
+def RAFT_infer(args):
+ model = torch.nn.DataParallel(RAFT(args))
+ model.load_state_dict(torch.load(args.model))
+
+ model = model.module
+ model.to(DEVICE)
+ model.eval()
+
+ return model
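+
+# Example (sketch; the checkpoint path and flags are assumptions, adjust them to the weights
+# actually shipped with this repo):
+# import argparse
+# args = argparse.Namespace(model='weights/raft-things.pth', path='frames/',
+# small=False, mixed_precision=False, alternate_corr=False)
+# flow_model = RAFT_infer(args)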
diff --git a/propainter/RAFT/extractor.py b/propainter/RAFT/extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a9c759d1243d4694e8656c2f6f8a37e53edd009
--- /dev/null
+++ b/propainter/RAFT/extractor.py
@@ -0,0 +1,267 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class ResidualBlock(nn.Module):
+ def __init__(self, in_planes, planes, norm_fn='group', stride=1):
+ super(ResidualBlock, self).__init__()
+
+ self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride)
+ self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1)
+ self.relu = nn.ReLU(inplace=True)
+
+ num_groups = planes // 8
+
+ if norm_fn == 'group':
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ if not stride == 1:
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+
+ elif norm_fn == 'batch':
+ self.norm1 = nn.BatchNorm2d(planes)
+ self.norm2 = nn.BatchNorm2d(planes)
+ if not stride == 1:
+ self.norm3 = nn.BatchNorm2d(planes)
+
+ elif norm_fn == 'instance':
+ self.norm1 = nn.InstanceNorm2d(planes)
+ self.norm2 = nn.InstanceNorm2d(planes)
+ if not stride == 1:
+ self.norm3 = nn.InstanceNorm2d(planes)
+
+ elif norm_fn == 'none':
+ self.norm1 = nn.Sequential()
+ self.norm2 = nn.Sequential()
+ if not stride == 1:
+ self.norm3 = nn.Sequential()
+
+ if stride == 1:
+ self.downsample = None
+
+ else:
+ self.downsample = nn.Sequential(
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3)
+
+
+ def forward(self, x):
+ y = x
+ y = self.relu(self.norm1(self.conv1(y)))
+ y = self.relu(self.norm2(self.conv2(y)))
+
+ if self.downsample is not None:
+ x = self.downsample(x)
+
+ return self.relu(x+y)
+
+
+
+class BottleneckBlock(nn.Module):
+ def __init__(self, in_planes, planes, norm_fn='group', stride=1):
+ super(BottleneckBlock, self).__init__()
+
+ self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0)
+ self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride)
+ self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0)
+ self.relu = nn.ReLU(inplace=True)
+
+ num_groups = planes // 8
+
+ if norm_fn == 'group':
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4)
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+ if not stride == 1:
+ self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
+
+ elif norm_fn == 'batch':
+ self.norm1 = nn.BatchNorm2d(planes//4)
+ self.norm2 = nn.BatchNorm2d(planes//4)
+ self.norm3 = nn.BatchNorm2d(planes)
+ if not stride == 1:
+ self.norm4 = nn.BatchNorm2d(planes)
+
+ elif norm_fn == 'instance':
+ self.norm1 = nn.InstanceNorm2d(planes//4)
+ self.norm2 = nn.InstanceNorm2d(planes//4)
+ self.norm3 = nn.InstanceNorm2d(planes)
+ if not stride == 1:
+ self.norm4 = nn.InstanceNorm2d(planes)
+
+ elif norm_fn == 'none':
+ self.norm1 = nn.Sequential()
+ self.norm2 = nn.Sequential()
+ self.norm3 = nn.Sequential()
+ if not stride == 1:
+ self.norm4 = nn.Sequential()
+
+ if stride == 1:
+ self.downsample = None
+
+ else:
+ self.downsample = nn.Sequential(
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4)
+
+
+ def forward(self, x):
+ y = x
+ y = self.relu(self.norm1(self.conv1(y)))
+ y = self.relu(self.norm2(self.conv2(y)))
+ y = self.relu(self.norm3(self.conv3(y)))
+
+ if self.downsample is not None:
+ x = self.downsample(x)
+
+ return self.relu(x+y)
+
+class BasicEncoder(nn.Module):
+ def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
+ super(BasicEncoder, self).__init__()
+ self.norm_fn = norm_fn
+
+ if self.norm_fn == 'group':
+ self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64)
+
+ elif self.norm_fn == 'batch':
+ self.norm1 = nn.BatchNorm2d(64)
+
+ elif self.norm_fn == 'instance':
+ self.norm1 = nn.InstanceNorm2d(64)
+
+ elif self.norm_fn == 'none':
+ self.norm1 = nn.Sequential()
+
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
+ self.relu1 = nn.ReLU(inplace=True)
+
+ self.in_planes = 64
+ self.layer1 = self._make_layer(64, stride=1)
+ self.layer2 = self._make_layer(96, stride=2)
+ self.layer3 = self._make_layer(128, stride=2)
+
+ # output convolution
+ self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1)
+
+ self.dropout = None
+ if dropout > 0:
+ self.dropout = nn.Dropout2d(p=dropout)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
+ if m.weight is not None:
+ nn.init.constant_(m.weight, 1)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, dim, stride=1):
+ layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride)
+ layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1)
+ layers = (layer1, layer2)
+
+ self.in_planes = dim
+ return nn.Sequential(*layers)
+
+
+ def forward(self, x):
+
+ # if input is list, combine batch dimension
+ is_list = isinstance(x, tuple) or isinstance(x, list)
+ if is_list:
+ batch_dim = x[0].shape[0]
+ x = torch.cat(x, dim=0)
+
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu1(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+
+ x = self.conv2(x)
+
+ if self.training and self.dropout is not None:
+ x = self.dropout(x)
+
+ if is_list:
+ x = torch.split(x, [batch_dim, batch_dim], dim=0)
+
+ return x
+
+
+class SmallEncoder(nn.Module):
+ def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0):
+ super(SmallEncoder, self).__init__()
+ self.norm_fn = norm_fn
+
+ if self.norm_fn == 'group':
+ self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32)
+
+ elif self.norm_fn == 'batch':
+ self.norm1 = nn.BatchNorm2d(32)
+
+ elif self.norm_fn == 'instance':
+ self.norm1 = nn.InstanceNorm2d(32)
+
+ elif self.norm_fn == 'none':
+ self.norm1 = nn.Sequential()
+
+ self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3)
+ self.relu1 = nn.ReLU(inplace=True)
+
+ self.in_planes = 32
+ self.layer1 = self._make_layer(32, stride=1)
+ self.layer2 = self._make_layer(64, stride=2)
+ self.layer3 = self._make_layer(96, stride=2)
+
+ self.dropout = None
+ if dropout > 0:
+ self.dropout = nn.Dropout2d(p=dropout)
+
+ self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1)
+
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)):
+ if m.weight is not None:
+ nn.init.constant_(m.weight, 1)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def _make_layer(self, dim, stride=1):
+ layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride)
+ layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1)
+ layers = (layer1, layer2)
+
+ self.in_planes = dim
+ return nn.Sequential(*layers)
+
+
+ def forward(self, x):
+
+ # if input is list, combine batch dimension
+ is_list = isinstance(x, tuple) or isinstance(x, list)
+ if is_list:
+ batch_dim = x[0].shape[0]
+ x = torch.cat(x, dim=0)
+
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu1(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.conv2(x)
+
+ if self.training and self.dropout is not None:
+ x = self.dropout(x)
+
+ if is_list:
+ x = torch.split(x, [batch_dim, batch_dim], dim=0)
+
+ return x
diff --git a/propainter/RAFT/raft.py b/propainter/RAFT/raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..829ef97b8d3e280aac59ebef7bb2eaf06274b62a
--- /dev/null
+++ b/propainter/RAFT/raft.py
@@ -0,0 +1,146 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .update import BasicUpdateBlock, SmallUpdateBlock
+from .extractor import BasicEncoder, SmallEncoder
+from .corr import CorrBlock, AlternateCorrBlock
+from .utils.utils import bilinear_sampler, coords_grid, upflow8
+
+try:
+ autocast = torch.cuda.amp.autocast
+except:
+ # dummy autocast for PyTorch < 1.6
+ class autocast:
+ def __init__(self, enabled):
+ pass
+ def __enter__(self):
+ pass
+ def __exit__(self, *args):
+ pass
+
+
+class RAFT(nn.Module):
+ def __init__(self, args):
+ super(RAFT, self).__init__()
+ self.args = args
+
+ if args.small:
+ self.hidden_dim = hdim = 96
+ self.context_dim = cdim = 64
+ args.corr_levels = 4
+ args.corr_radius = 3
+
+ else:
+ self.hidden_dim = hdim = 128
+ self.context_dim = cdim = 128
+ args.corr_levels = 4
+ args.corr_radius = 4
+
+ if not hasattr(args, 'dropout'):
+ args.dropout = 0
+
+ if not hasattr(args, 'alternate_corr'):
+ args.alternate_corr = False
+
+ # feature network, context network, and update block
+ if args.small:
+ self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout)
+ self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout)
+ self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim)
+
+ else:
+ self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout)
+ self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout)
+ self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim)
+
+
+ def freeze_bn(self):
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.eval()
+
+ def initialize_flow(self, img):
+ """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0"""
+ N, C, H, W = img.shape
+ coords0 = coords_grid(N, H//8, W//8).to(img.device)
+ coords1 = coords_grid(N, H//8, W//8).to(img.device)
+
+ # optical flow computed as difference: flow = coords1 - coords0
+ return coords0, coords1
+
+ def upsample_flow(self, flow, mask):
+ """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """
+ N, _, H, W = flow.shape
+ mask = mask.view(N, 1, 9, 8, 8, H, W)
+ mask = torch.softmax(mask, dim=2)
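+ # softmax over the 9 neighbours yields convex weights, so each fine-grid flow vector is a
+ # convex combination of the 3x3 coarse flow vectors around it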
+
+ up_flow = F.unfold(8 * flow, [3,3], padding=1)
+ up_flow = up_flow.view(N, 2, 9, 1, 1, H, W)
+
+ up_flow = torch.sum(mask * up_flow, dim=2)
+ up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
+ return up_flow.reshape(N, 2, 8*H, 8*W)
+
+
+ def forward(self, image1, image2, iters=12, flow_init=None, test_mode=True):
+ """ Estimate optical flow between pair of frames """
+
+ # image1 = 2 * (image1 / 255.0) - 1.0
+ # image2 = 2 * (image2 / 255.0) - 1.0
+
+ image1 = image1.contiguous()
+ image2 = image2.contiguous()
+
+ hdim = self.hidden_dim
+ cdim = self.context_dim
+
+ # run the feature network
+ with autocast(enabled=self.args.mixed_precision):
+ fmap1, fmap2 = self.fnet([image1, image2])
+
+ fmap1 = fmap1.float()
+ fmap2 = fmap2.float()
+
+ if self.args.alternate_corr:
+ corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
+ else:
+ corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius)
+
+ # run the context network
+ with autocast(enabled=self.args.mixed_precision):
+ cnet = self.cnet(image1)
+ net, inp = torch.split(cnet, [hdim, cdim], dim=1)
+ net = torch.tanh(net)
+ inp = torch.relu(inp)
+
+ coords0, coords1 = self.initialize_flow(image1)
+
+ if flow_init is not None:
+ coords1 = coords1 + flow_init
+
+ flow_predictions = []
+ for itr in range(iters):
+ coords1 = coords1.detach()
+ corr = corr_fn(coords1) # index correlation volume
+
+ flow = coords1 - coords0
+ with autocast(enabled=self.args.mixed_precision):
+ net, up_mask, delta_flow = self.update_block(net, inp, corr, flow)
+
+ # F(t+1) = F(t) + \Delta(t)
+ coords1 = coords1 + delta_flow
+
+ # upsample predictions
+ if up_mask is None:
+ flow_up = upflow8(coords1 - coords0)
+ else:
+ flow_up = self.upsample_flow(coords1 - coords0, up_mask)
+
+ flow_predictions.append(flow_up)
+
+ if test_mode:
+ return coords1 - coords0, flow_up
+
+ return flow_predictions
diff --git a/propainter/RAFT/update.py b/propainter/RAFT/update.py
new file mode 100644
index 0000000000000000000000000000000000000000..f940497f9b5eb1c12091574fe9a0223a1b196d50
--- /dev/null
+++ b/propainter/RAFT/update.py
@@ -0,0 +1,139 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class FlowHead(nn.Module):
+ def __init__(self, input_dim=128, hidden_dim=256):
+ super(FlowHead, self).__init__()
+ self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1)
+ self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1)
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ return self.conv2(self.relu(self.conv1(x)))
+
+class ConvGRU(nn.Module):
+ def __init__(self, hidden_dim=128, input_dim=192+128):
+ super(ConvGRU, self).__init__()
+ self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
+ self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
+ self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1)
+
+ def forward(self, h, x):
+ hx = torch.cat([h, x], dim=1)
+
+ z = torch.sigmoid(self.convz(hx))
+ r = torch.sigmoid(self.convr(hx))
+ q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1)))
+
+ h = (1-z) * h + z * q
+ return h
+
+class SepConvGRU(nn.Module):
+ def __init__(self, hidden_dim=128, input_dim=192+128):
+ super(SepConvGRU, self).__init__()
+ self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
+ self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
+ self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2))
+
+ self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
+ self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
+ self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0))
+
+
+ def forward(self, h, x):
+ # horizontal
+ hx = torch.cat([h, x], dim=1)
+ z = torch.sigmoid(self.convz1(hx))
+ r = torch.sigmoid(self.convr1(hx))
+ q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1)))
+ h = (1-z) * h + z * q
+
+ # vertical
+ hx = torch.cat([h, x], dim=1)
+ z = torch.sigmoid(self.convz2(hx))
+ r = torch.sigmoid(self.convr2(hx))
+ q = torch.tanh(self.convq2(torch.cat([r*h, x], dim=1)))
+ h = (1-z) * h + z * q
+
+ return h
+
+class SmallMotionEncoder(nn.Module):
+ def __init__(self, args):
+ super(SmallMotionEncoder, self).__init__()
+ cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
+ self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0)
+ self.convf1 = nn.Conv2d(2, 64, 7, padding=3)
+ self.convf2 = nn.Conv2d(64, 32, 3, padding=1)
+ self.conv = nn.Conv2d(128, 80, 3, padding=1)
+
+ def forward(self, flow, corr):
+ cor = F.relu(self.convc1(corr))
+ flo = F.relu(self.convf1(flow))
+ flo = F.relu(self.convf2(flo))
+ cor_flo = torch.cat([cor, flo], dim=1)
+ out = F.relu(self.conv(cor_flo))
+ return torch.cat([out, flow], dim=1)
+
+class BasicMotionEncoder(nn.Module):
+ def __init__(self, args):
+ super(BasicMotionEncoder, self).__init__()
+ cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2
+ self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0)
+ self.convc2 = nn.Conv2d(256, 192, 3, padding=1)
+ self.convf1 = nn.Conv2d(2, 128, 7, padding=3)
+ self.convf2 = nn.Conv2d(128, 64, 3, padding=1)
+ self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1)
+
+ def forward(self, flow, corr):
+ cor = F.relu(self.convc1(corr))
+ cor = F.relu(self.convc2(cor))
+ flo = F.relu(self.convf1(flow))
+ flo = F.relu(self.convf2(flo))
+
+ cor_flo = torch.cat([cor, flo], dim=1)
+ out = F.relu(self.conv(cor_flo))
+ return torch.cat([out, flow], dim=1)
+
+class SmallUpdateBlock(nn.Module):
+ def __init__(self, args, hidden_dim=96):
+ super(SmallUpdateBlock, self).__init__()
+ self.encoder = SmallMotionEncoder(args)
+ self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64)
+ self.flow_head = FlowHead(hidden_dim, hidden_dim=128)
+
+ def forward(self, net, inp, corr, flow):
+ motion_features = self.encoder(flow, corr)
+ inp = torch.cat([inp, motion_features], dim=1)
+ net = self.gru(net, inp)
+ delta_flow = self.flow_head(net)
+
+ return net, None, delta_flow
+
+class BasicUpdateBlock(nn.Module):
+ def __init__(self, args, hidden_dim=128, input_dim=128):
+ super(BasicUpdateBlock, self).__init__()
+ self.args = args
+ self.encoder = BasicMotionEncoder(args)
+ self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim)
+ self.flow_head = FlowHead(hidden_dim, hidden_dim=256)
+
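+ # convex-upsampling mask: 9 combination weights for each cell of the 8x8 upsampling grid (64*9 channels)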
+ self.mask = nn.Sequential(
+ nn.Conv2d(128, 256, 3, padding=1),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(256, 64*9, 1, padding=0))
+
+ def forward(self, net, inp, corr, flow, upsample=True):
+ motion_features = self.encoder(flow, corr)
+ inp = torch.cat([inp, motion_features], dim=1)
+
+ net = self.gru(net, inp)
+ delta_flow = self.flow_head(net)
+
+ # scale mask to balance gradients
+ mask = .25 * self.mask(net)
+ return net, mask, delta_flow
+
+
+
diff --git a/propainter/RAFT/utils/__init__.py b/propainter/RAFT/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0437149bfee42718973728158641020ccc1906ad
--- /dev/null
+++ b/propainter/RAFT/utils/__init__.py
@@ -0,0 +1,2 @@
+from .flow_viz import flow_to_image
+from .frame_utils import writeFlow
diff --git a/propainter/RAFT/utils/augmentor.py b/propainter/RAFT/utils/augmentor.py
new file mode 100644
index 0000000000000000000000000000000000000000..e81c4f2b5c16c31c0ae236d744f299d430228a04
--- /dev/null
+++ b/propainter/RAFT/utils/augmentor.py
@@ -0,0 +1,246 @@
+import numpy as np
+import random
+import math
+from PIL import Image
+
+import cv2
+cv2.setNumThreads(0)
+cv2.ocl.setUseOpenCL(False)
+
+import torch
+from torchvision.transforms import ColorJitter
+import torch.nn.functional as F
+
+
+class FlowAugmentor:
+ def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True):
+
+ # spatial augmentation params
+ self.crop_size = crop_size
+ self.min_scale = min_scale
+ self.max_scale = max_scale
+ self.spatial_aug_prob = 0.8
+ self.stretch_prob = 0.8
+ self.max_stretch = 0.2
+
+ # flip augmentation params
+ self.do_flip = do_flip
+ self.h_flip_prob = 0.5
+ self.v_flip_prob = 0.1
+
+ # photometric augmentation params
+ self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14)
+ self.asymmetric_color_aug_prob = 0.2
+ self.eraser_aug_prob = 0.5
+
+ def color_transform(self, img1, img2):
+ """ Photometric augmentation """
+
+ # asymmetric
+ if np.random.rand() < self.asymmetric_color_aug_prob:
+ img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8)
+ img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8)
+
+ # symmetric
+ else:
+ image_stack = np.concatenate([img1, img2], axis=0)
+ image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
+ img1, img2 = np.split(image_stack, 2, axis=0)
+
+ return img1, img2
+
+ def eraser_transform(self, img1, img2, bounds=[50, 100]):
+ """ Occlusion augmentation """
+
+ ht, wd = img1.shape[:2]
+ if np.random.rand() < self.eraser_aug_prob:
+ mean_color = np.mean(img2.reshape(-1, 3), axis=0)
+ for _ in range(np.random.randint(1, 3)):
+ x0 = np.random.randint(0, wd)
+ y0 = np.random.randint(0, ht)
+ dx = np.random.randint(bounds[0], bounds[1])
+ dy = np.random.randint(bounds[0], bounds[1])
+ img2[y0:y0+dy, x0:x0+dx, :] = mean_color
+
+ return img1, img2
+
+ def spatial_transform(self, img1, img2, flow):
+ # randomly sample scale
+ ht, wd = img1.shape[:2]
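+ # lower bound on the scale so that the crop (plus an 8-pixel margin) still fits after resizing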
+ min_scale = np.maximum(
+ (self.crop_size[0] + 8) / float(ht),
+ (self.crop_size[1] + 8) / float(wd))
+
+ scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
+ scale_x = scale
+ scale_y = scale
+ if np.random.rand() < self.stretch_prob:
+ scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
+ scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch)
+
+ scale_x = np.clip(scale_x, min_scale, None)
+ scale_y = np.clip(scale_y, min_scale, None)
+
+ if np.random.rand() < self.spatial_aug_prob:
+ # rescale the images
+ img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
+ img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
+ flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
+ flow = flow * [scale_x, scale_y]
+
+ if self.do_flip:
+ if np.random.rand() < self.h_flip_prob: # h-flip
+ img1 = img1[:, ::-1]
+ img2 = img2[:, ::-1]
+ flow = flow[:, ::-1] * [-1.0, 1.0]
+
+ if np.random.rand() < self.v_flip_prob: # v-flip
+ img1 = img1[::-1, :]
+ img2 = img2[::-1, :]
+ flow = flow[::-1, :] * [1.0, -1.0]
+
+ y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0])
+ x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1])
+
+ img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
+ img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
+ flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
+
+ return img1, img2, flow
+
+ def __call__(self, img1, img2, flow):
+ img1, img2 = self.color_transform(img1, img2)
+ img1, img2 = self.eraser_transform(img1, img2)
+ img1, img2, flow = self.spatial_transform(img1, img2, flow)
+
+ img1 = np.ascontiguousarray(img1)
+ img2 = np.ascontiguousarray(img2)
+ flow = np.ascontiguousarray(flow)
+
+ return img1, img2, flow
+
+class SparseFlowAugmentor:
+ def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False):
+ # spatial augmentation params
+ self.crop_size = crop_size
+ self.min_scale = min_scale
+ self.max_scale = max_scale
+ self.spatial_aug_prob = 0.8
+ self.stretch_prob = 0.8
+ self.max_stretch = 0.2
+
+ # flip augmentation params
+ self.do_flip = do_flip
+ self.h_flip_prob = 0.5
+ self.v_flip_prob = 0.1
+
+ # photometric augmentation params
+ self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14)
+ self.asymmetric_color_aug_prob = 0.2
+ self.eraser_aug_prob = 0.5
+
+ def color_transform(self, img1, img2):
+ image_stack = np.concatenate([img1, img2], axis=0)
+ image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8)
+ img1, img2 = np.split(image_stack, 2, axis=0)
+ return img1, img2
+
+ def eraser_transform(self, img1, img2):
+ ht, wd = img1.shape[:2]
+ if np.random.rand() < self.eraser_aug_prob:
+ mean_color = np.mean(img2.reshape(-1, 3), axis=0)
+ for _ in range(np.random.randint(1, 3)):
+ x0 = np.random.randint(0, wd)
+ y0 = np.random.randint(0, ht)
+ dx = np.random.randint(50, 100)
+ dy = np.random.randint(50, 100)
+ img2[y0:y0+dy, x0:x0+dx, :] = mean_color
+
+ return img1, img2
+
+ def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0):
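+ # resize sparse flow by scattering valid vectors onto the new grid instead of interpolating, so invalid pixels never contaminate valid ones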
+ ht, wd = flow.shape[:2]
+ coords = np.meshgrid(np.arange(wd), np.arange(ht))
+ coords = np.stack(coords, axis=-1)
+
+ coords = coords.reshape(-1, 2).astype(np.float32)
+ flow = flow.reshape(-1, 2).astype(np.float32)
+ valid = valid.reshape(-1).astype(np.float32)
+
+ coords0 = coords[valid>=1]
+ flow0 = flow[valid>=1]
+
+ ht1 = int(round(ht * fy))
+ wd1 = int(round(wd * fx))
+
+ coords1 = coords0 * [fx, fy]
+ flow1 = flow0 * [fx, fy]
+
+ xx = np.round(coords1[:,0]).astype(np.int32)
+ yy = np.round(coords1[:,1]).astype(np.int32)
+
+ v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1)
+ xx = xx[v]
+ yy = yy[v]
+ flow1 = flow1[v]
+
+ flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32)
+ valid_img = np.zeros([ht1, wd1], dtype=np.int32)
+
+ flow_img[yy, xx] = flow1
+ valid_img[yy, xx] = 1
+
+ return flow_img, valid_img
+
+ def spatial_transform(self, img1, img2, flow, valid):
+ # randomly sample scale
+
+ ht, wd = img1.shape[:2]
+ min_scale = np.maximum(
+ (self.crop_size[0] + 1) / float(ht),
+ (self.crop_size[1] + 1) / float(wd))
+
+ scale = 2 ** np.random.uniform(self.min_scale, self.max_scale)
+ scale_x = np.clip(scale, min_scale, None)
+ scale_y = np.clip(scale, min_scale, None)
+
+ if np.random.rand() < self.spatial_aug_prob:
+ # rescale the images
+ img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
+ img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR)
+ flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y)
+
+ if self.do_flip:
+ if np.random.rand() < 0.5: # h-flip
+ img1 = img1[:, ::-1]
+ img2 = img2[:, ::-1]
+ flow = flow[:, ::-1] * [-1.0, 1.0]
+ valid = valid[:, ::-1]
+
+ margin_y = 20
+ margin_x = 50
+
+ y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y)
+ x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x)
+
+ y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0])
+ x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1])
+
+ img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
+ img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
+ flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
+ valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]]
+ return img1, img2, flow, valid
+
+
+ def __call__(self, img1, img2, flow, valid):
+ img1, img2 = self.color_transform(img1, img2)
+ img1, img2 = self.eraser_transform(img1, img2)
+ img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid)
+
+ img1 = np.ascontiguousarray(img1)
+ img2 = np.ascontiguousarray(img2)
+ flow = np.ascontiguousarray(flow)
+ valid = np.ascontiguousarray(valid)
+
+ return img1, img2, flow, valid
diff --git a/propainter/RAFT/utils/flow_viz.py b/propainter/RAFT/utils/flow_viz.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcee65e89b91b07ee0496aeb4c7e7436abf99641
--- /dev/null
+++ b/propainter/RAFT/utils/flow_viz.py
@@ -0,0 +1,132 @@
+# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
+
+
+# MIT License
+#
+# Copyright (c) 2018 Tom Runia
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to conditions.
+#
+# Author: Tom Runia
+# Date Created: 2018-08-03
+
+import numpy as np
+
+def make_colorwheel():
+ """
+ Generates a color wheel for optical flow visualization as presented in:
+ Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
+ URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
+
+ Code follows the original C++ source code of Daniel Scharstein.
+ Code follows the Matlab source code of Deqing Sun.
+
+ Returns:
+ np.ndarray: Color wheel
+ """
+
+ RY = 15
+ YG = 6
+ GC = 4
+ CB = 11
+ BM = 13
+ MR = 6
+
+ ncols = RY + YG + GC + CB + BM + MR
+ colorwheel = np.zeros((ncols, 3))
+ col = 0
+
+ # RY
+ colorwheel[0:RY, 0] = 255
+ colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY)
+ col = col+RY
+ # YG
+ colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG)
+ colorwheel[col:col+YG, 1] = 255
+ col = col+YG
+ # GC
+ colorwheel[col:col+GC, 1] = 255
+ colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC)
+ col = col+GC
+ # CB
+ colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB)
+ colorwheel[col:col+CB, 2] = 255
+ col = col+CB
+ # BM
+ colorwheel[col:col+BM, 2] = 255
+ colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM)
+ col = col+BM
+ # MR
+ colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR)
+ colorwheel[col:col+MR, 0] = 255
+ return colorwheel
+
+
+def flow_uv_to_colors(u, v, convert_to_bgr=False):
+ """
+ Applies the flow color wheel to (possibly clipped) flow components u and v.
+
+ According to the C++ source code of Daniel Scharstein
+ According to the Matlab source code of Deqing Sun
+
+ Args:
+ u (np.ndarray): Input horizontal flow of shape [H,W]
+ v (np.ndarray): Input vertical flow of shape [H,W]
+ convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
+
+ Returns:
+ np.ndarray: Flow visualization image of shape [H,W,3]
+ """
+ flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
+ colorwheel = make_colorwheel() # shape [55x3]
+ ncols = colorwheel.shape[0]
+ rad = np.sqrt(np.square(u) + np.square(v))
+ a = np.arctan2(-v, -u)/np.pi
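+ # map the flow angle to a fractional index on the color wheel and interpolate between the two nearest colors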
+ fk = (a+1) / 2*(ncols-1)
+ k0 = np.floor(fk).astype(np.int32)
+ k1 = k0 + 1
+ k1[k1 == ncols] = 0
+ f = fk - k0
+ for i in range(colorwheel.shape[1]):
+ tmp = colorwheel[:,i]
+ col0 = tmp[k0] / 255.0
+ col1 = tmp[k1] / 255.0
+ col = (1-f)*col0 + f*col1
+ idx = (rad <= 1)
+ col[idx] = 1 - rad[idx] * (1-col[idx])
+ col[~idx] = col[~idx] * 0.75 # out of range
+ # Note the 2-i => BGR instead of RGB
+ ch_idx = 2-i if convert_to_bgr else i
+ flow_image[:,:,ch_idx] = np.floor(255 * col)
+ return flow_image
+
+
+def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
+ """
+ Expects a two-dimensional flow image of shape [H,W,2].
+
+ Args:
+ flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
+ clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
+ convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
+
+ Returns:
+ np.ndarray: Flow visualization image of shape [H,W,3]
+ """
+ assert flow_uv.ndim == 3, 'input flow must have three dimensions'
+ assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
+ if clip_flow is not None:
+ flow_uv = np.clip(flow_uv, 0, clip_flow)
+ u = flow_uv[:,:,0]
+ v = flow_uv[:,:,1]
+ rad = np.sqrt(np.square(u) + np.square(v))
+ rad_max = np.max(rad)
+ epsilon = 1e-5
+ u = u / (rad_max + epsilon)
+ v = v / (rad_max + epsilon)
+ return flow_uv_to_colors(u, v, convert_to_bgr)
\ No newline at end of file
diff --git a/propainter/RAFT/utils/flow_viz_pt.py b/propainter/RAFT/utils/flow_viz_pt.py
new file mode 100644
index 0000000000000000000000000000000000000000..12e666a40fa49c11592e311b141aa2a522e567fd
--- /dev/null
+++ b/propainter/RAFT/utils/flow_viz_pt.py
@@ -0,0 +1,118 @@
+# Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization
+import torch
+torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732
+
+@torch.no_grad()
+def flow_to_image(flow: torch.Tensor) -> torch.Tensor:
+
+ """
+ Converts a flow to an RGB image.
+
+ Args:
+ flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float.
+
+ Returns:
+ img (Tensor): Image Tensor of dtype uint8 where each color corresponds
+ to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input.
+ """
+
+ if flow.dtype != torch.float:
+ raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.")
+
+ orig_shape = flow.shape
+ if flow.ndim == 3:
+ flow = flow[None] # Add batch dim
+
+ if flow.ndim != 4 or flow.shape[1] != 2:
+ raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.")
+
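+ # normalize by the largest flow magnitude in the batch; direction sets the hue, relative magnitude sets the saturation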
+ max_norm = torch.sum(flow**2, dim=1).sqrt().max()
+ epsilon = torch.finfo((flow).dtype).eps
+ normalized_flow = flow / (max_norm + epsilon)
+ img = _normalized_flow_to_image(normalized_flow)
+
+ if len(orig_shape) == 3:
+ img = img[0] # Remove batch dim
+ return img
+
+@torch.no_grad()
+def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor:
+
+ """
+ Converts a batch of normalized flow to an RGB image.
+
+ Args:
+ normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W)
+ Returns:
+ img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8.
+ """
+
+ N, _, H, W = normalized_flow.shape
+ device = normalized_flow.device
+ flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device)
+ colorwheel = _make_colorwheel().to(device) # shape [55x3]
+ num_cols = colorwheel.shape[0]
+ norm = torch.sum(normalized_flow**2, dim=1).sqrt()
+ a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi
+ fk = (a + 1) / 2 * (num_cols - 1)
+ k0 = torch.floor(fk).to(torch.long)
+ k1 = k0 + 1
+ k1[k1 == num_cols] = 0
+ f = fk - k0
+
+ for c in range(colorwheel.shape[1]):
+ tmp = colorwheel[:, c]
+ col0 = tmp[k0] / 255.0
+ col1 = tmp[k1] / 255.0
+ col = (1 - f) * col0 + f * col1
+ col = 1 - norm * (1 - col)
+ flow_image[:, c, :, :] = torch.floor(255. * col)
+ return flow_image
+
+
+@torch.no_grad()
+def _make_colorwheel() -> torch.Tensor:
+ """
+ Generates a color wheel for optical flow visualization as presented in:
+ Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
+ URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf.
+
+ Returns:
+ colorwheel (Tensor[55, 3]): Colorwheel Tensor.
+ """
+
+ RY = 15
+ YG = 6
+ GC = 4
+ CB = 11
+ BM = 13
+ MR = 6
+
+ ncols = RY + YG + GC + CB + BM + MR
+ colorwheel = torch.zeros((ncols, 3))
+ col = 0
+
+ # RY
+ colorwheel[0:RY, 0] = 255
+ colorwheel[0:RY, 1] = torch.floor(255. * torch.arange(0., RY) / RY)
+ col = col + RY
+ # YG
+ colorwheel[col : col + YG, 0] = 255 - torch.floor(255. * torch.arange(0., YG) / YG)
+ colorwheel[col : col + YG, 1] = 255
+ col = col + YG
+ # GC
+ colorwheel[col : col + GC, 1] = 255
+ colorwheel[col : col + GC, 2] = torch.floor(255. * torch.arange(0., GC) / GC)
+ col = col + GC
+ # CB
+ colorwheel[col : col + CB, 1] = 255 - torch.floor(255. * torch.arange(CB) / CB)
+ colorwheel[col : col + CB, 2] = 255
+ col = col + CB
+ # BM
+ colorwheel[col : col + BM, 2] = 255
+ colorwheel[col : col + BM, 0] = torch.floor(255. * torch.arange(0., BM) / BM)
+ col = col + BM
+ # MR
+ colorwheel[col : col + MR, 2] = 255 - torch.floor(255. * torch.arange(MR) / MR)
+ colorwheel[col : col + MR, 0] = 255
+ return colorwheel
diff --git a/propainter/RAFT/utils/frame_utils.py b/propainter/RAFT/utils/frame_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c491135efaffc25bd61ec3ecde99d236f5deb12
--- /dev/null
+++ b/propainter/RAFT/utils/frame_utils.py
@@ -0,0 +1,137 @@
+import numpy as np
+from PIL import Image
+from os.path import *
+import re
+
+import cv2
+cv2.setNumThreads(0)
+cv2.ocl.setUseOpenCL(False)
+
+TAG_CHAR = np.array([202021.25], np.float32)
+
+def readFlow(fn):
+ """ Read .flo file in Middlebury format"""
+ # Code adapted from:
+ # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy
+
+ # WARNING: this will work on little-endian architectures (eg Intel x86) only!
+ # print 'fn = %s'%(fn)
+ with open(fn, 'rb') as f:
+ magic = np.fromfile(f, np.float32, count=1)
+ if 202021.25 != magic:
+ print('Magic number incorrect. Invalid .flo file')
+ return None
+ else:
+ w = np.fromfile(f, np.int32, count=1)
+ h = np.fromfile(f, np.int32, count=1)
+ # print 'Reading %d x %d flo file\n' % (w, h)
+ data = np.fromfile(f, np.float32, count=2*int(w)*int(h))
+ # Reshape data into 3D array (columns, rows, bands)
+ # The reshape here is for visualization, the original code is (w,h,2)
+ return np.resize(data, (int(h), int(w), 2))
+
+def readPFM(file):
+ file = open(file, 'rb')
+
+ color = None
+ width = None
+ height = None
+ scale = None
+ endian = None
+
+ header = file.readline().rstrip()
+ if header == b'PF':
+ color = True
+ elif header == b'Pf':
+ color = False
+ else:
+ raise Exception('Not a PFM file.')
+
+ dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline())
+ if dim_match:
+ width, height = map(int, dim_match.groups())
+ else:
+ raise Exception('Malformed PFM header.')
+
+ scale = float(file.readline().rstrip())
+ if scale < 0: # little-endian
+ endian = '<'
+ scale = -scale
+ else:
+ endian = '>' # big-endian
+
+ data = np.fromfile(file, endian + 'f')
+ shape = (height, width, 3) if color else (height, width)
+
+ data = np.reshape(data, shape)
+ data = np.flipud(data)
+ return data
+
+def writeFlow(filename,uv,v=None):
+ """ Write optical flow to file.
+
+ If v is None, uv is assumed to contain both u and v channels,
+ stacked in depth.
+ Original code by Deqing Sun, adapted from Daniel Scharstein.
+ """
+ nBands = 2
+
+ if v is None:
+ assert(uv.ndim == 3)
+ assert(uv.shape[2] == 2)
+ u = uv[:,:,0]
+ v = uv[:,:,1]
+ else:
+ u = uv
+
+ assert(u.shape == v.shape)
+ height,width = u.shape
+ f = open(filename,'wb')
+ # write the header
+ f.write(TAG_CHAR)
+ np.array(width).astype(np.int32).tofile(f)
+ np.array(height).astype(np.int32).tofile(f)
+ # arrange into matrix form
+ tmp = np.zeros((height, width*nBands))
+ tmp[:,np.arange(width)*2] = u
+ tmp[:,np.arange(width)*2 + 1] = v
+ tmp.astype(np.float32).tofile(f)
+ f.close()
+
+
+def readFlowKITTI(filename):
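+ # KITTI stores flow as 16-bit PNG: pixel = 64 * flow + 2**15; the third channel is the validity mask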
+ flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR)
+ flow = flow[:,:,::-1].astype(np.float32)
+ flow, valid = flow[:, :, :2], flow[:, :, 2]
+ flow = (flow - 2**15) / 64.0
+ return flow, valid
+
+def readDispKITTI(filename):
+ disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0
+ valid = disp > 0.0
+ flow = np.stack([-disp, np.zeros_like(disp)], -1)
+ return flow, valid
+
+
+def writeFlowKITTI(filename, uv):
+ uv = 64.0 * uv + 2**15
+ valid = np.ones([uv.shape[0], uv.shape[1], 1])
+ uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16)
+ cv2.imwrite(filename, uv[..., ::-1])
+
+
+def read_gen(file_name, pil=False):
+ ext = splitext(file_name)[-1]
+ if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg':
+ return Image.open(file_name)
+ elif ext == '.bin' or ext == '.raw':
+ return np.load(file_name)
+ elif ext == '.flo':
+ return readFlow(file_name).astype(np.float32)
+ elif ext == '.pfm':
+ flow = readPFM(file_name).astype(np.float32)
+ if len(flow.shape) == 2:
+ return flow
+ else:
+ return flow[:, :, :-1]
+ return []
\ No newline at end of file
diff --git a/propainter/RAFT/utils/utils.py b/propainter/RAFT/utils/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f32d281c1c46353a0a2bf36b0550adb74125c65
--- /dev/null
+++ b/propainter/RAFT/utils/utils.py
@@ -0,0 +1,82 @@
+import torch
+import torch.nn.functional as F
+import numpy as np
+from scipy import interpolate
+
+
+class InputPadder:
+ """ Pads images such that dimensions are divisible by 8 """
+ def __init__(self, dims, mode='sintel'):
+ self.ht, self.wd = dims[-2:]
+ pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
+ pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
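+ # F.pad order is (left, right, top, bottom): 'sintel' pads the height symmetrically, other modes pad only the bottom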
+ if mode == 'sintel':
+ self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2]
+ else:
+ self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht]
+
+ def pad(self, *inputs):
+ return [F.pad(x, self._pad, mode='replicate') for x in inputs]
+
+ def unpad(self,x):
+ ht, wd = x.shape[-2:]
+ c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]]
+ return x[..., c[0]:c[1], c[2]:c[3]]
+
+def forward_interpolate(flow):
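+ # forward-warp each flow vector to its target location and fill holes by nearest-neighbour interpolation (useful for warm-starting the next frame's estimate)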
+ flow = flow.detach().cpu().numpy()
+ dx, dy = flow[0], flow[1]
+
+ ht, wd = dx.shape
+ x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht))
+
+ x1 = x0 + dx
+ y1 = y0 + dy
+
+ x1 = x1.reshape(-1)
+ y1 = y1.reshape(-1)
+ dx = dx.reshape(-1)
+ dy = dy.reshape(-1)
+
+ valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht)
+ x1 = x1[valid]
+ y1 = y1[valid]
+ dx = dx[valid]
+ dy = dy[valid]
+
+ flow_x = interpolate.griddata(
+ (x1, y1), dx, (x0, y0), method='nearest', fill_value=0)
+
+ flow_y = interpolate.griddata(
+ (x1, y1), dy, (x0, y0), method='nearest', fill_value=0)
+
+ flow = np.stack([flow_x, flow_y], axis=0)
+ return torch.from_numpy(flow).float()
+
+
+def bilinear_sampler(img, coords, mode='bilinear', mask=False):
+ """ Wrapper for grid_sample, uses pixel coordinates """
+ H, W = img.shape[-2:]
+ xgrid, ygrid = coords.split([1,1], dim=-1)
+ xgrid = 2*xgrid/(W-1) - 1
+ ygrid = 2*ygrid/(H-1) - 1
+
+ grid = torch.cat([xgrid, ygrid], dim=-1)
+ img = F.grid_sample(img, grid, align_corners=True)
+
+ if mask:
+ mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
+ return img, mask.float()
+
+ return img
+
+
+def coords_grid(batch, ht, wd):
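+ # pixel coordinate grid of shape [batch, 2, ht, wd], channel 0 = x and channel 1 = y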
+ coords = torch.meshgrid(torch.arange(ht), torch.arange(wd))
+ coords = torch.stack(coords[::-1], dim=0).float()
+ return coords[None].repeat(batch, 1, 1, 1)
+
+
+def upflow8(flow, mode='bilinear'):
+ new_size = (8 * flow.shape[2], 8 * flow.shape[3])
+ return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True)
diff --git a/propainter/core/dataset.py b/propainter/core/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..27b135bb7716f0e89d9a3ec9fd4411dfe3eb94eb
--- /dev/null
+++ b/propainter/core/dataset.py
@@ -0,0 +1,232 @@
+import os
+import json
+import random
+
+import cv2
+from PIL import Image
+import numpy as np
+
+import torch
+import torchvision.transforms as transforms
+
+from utils.file_client import FileClient
+from utils.img_util import imfrombytes
+from utils.flow_util import resize_flow, flowread
+from core.utils import (create_random_shape_with_random_motion, Stack,
+ ToTorchFormatTensor, GroupRandomHorizontalFlip,GroupRandomHorizontalFlowFlip)
+
+
+class TrainDataset(torch.utils.data.Dataset):
+ def __init__(self, args: dict):
+ self.args = args
+ self.video_root = args['video_root']
+ self.flow_root = args['flow_root']
+ self.num_local_frames = args['num_local_frames']
+ self.num_ref_frames = args['num_ref_frames']
+ self.size = self.w, self.h = (args['w'], args['h'])
+
+ self.load_flow = args['load_flow']
+ if self.load_flow:
+ assert os.path.exists(self.flow_root)
+
+ json_path = os.path.join('./datasets', args['name'], 'train.json')
+
+ with open(json_path, 'r') as f:
+ self.video_train_dict = json.load(f)
+ self.video_names = sorted(list(self.video_train_dict.keys()))
+
+ # self.video_names = sorted(os.listdir(self.video_root))
+ self.video_dict = {}
+ self.frame_dict = {}
+
+ for v in self.video_names:
+ frame_list = sorted(os.listdir(os.path.join(self.video_root, v)))
+ v_len = len(frame_list)
+ if v_len > self.num_local_frames + self.num_ref_frames:
+ self.video_dict[v] = v_len
+ self.frame_dict[v] = frame_list
+
+
+ self.video_names = list(self.video_dict.keys()) # update names
+
+ self._to_tensors = transforms.Compose([
+ Stack(),
+ ToTorchFormatTensor(),
+ ])
+ self.file_client = FileClient('disk')
+
+ def __len__(self):
+ return len(self.video_names)
+
+ def _sample_index(self, length, sample_length, num_ref_frame=3):
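+ # sample a contiguous window of local frames plus num_ref_frame non-overlapping reference frames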
+ complete_idx_set = list(range(length))
+ pivot = random.randint(0, length - sample_length)
+ local_idx = complete_idx_set[pivot:pivot + sample_length]
+ remain_idx = list(set(complete_idx_set) - set(local_idx))
+ ref_index = sorted(random.sample(remain_idx, num_ref_frame))
+
+ return local_idx + ref_index
+
+ def __getitem__(self, index):
+ video_name = self.video_names[index]
+ # create masks
+ all_masks = create_random_shape_with_random_motion(
+ self.video_dict[video_name], imageHeight=self.h, imageWidth=self.w)
+
+ # create sample index
+ selected_index = self._sample_index(self.video_dict[video_name],
+ self.num_local_frames,
+ self.num_ref_frames)
+
+ # read video frames
+ frames = []
+ masks = []
+ flows_f, flows_b = [], []
+ for idx in selected_index:
+ frame_list = self.frame_dict[video_name]
+ img_path = os.path.join(self.video_root, video_name, frame_list[idx])
+ img_bytes = self.file_client.get(img_path, 'img')
+ img = imfrombytes(img_bytes, float32=False)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img = cv2.resize(img, self.size, interpolation=cv2.INTER_LINEAR)
+ img = Image.fromarray(img)
+
+ frames.append(img)
+ masks.append(all_masks[idx])
+
+ if len(frames) <= self.num_local_frames-1 and self.load_flow:
+ current_n = frame_list[idx][:-4]
+ next_n = frame_list[idx+1][:-4]
+ flow_f_path = os.path.join(self.flow_root, video_name, f'{current_n}_{next_n}_f.flo')
+ flow_b_path = os.path.join(self.flow_root, video_name, f'{next_n}_{current_n}_b.flo')
+ flow_f = flowread(flow_f_path, quantize=False)
+ flow_b = flowread(flow_b_path, quantize=False)
+ flow_f = resize_flow(flow_f, self.h, self.w)
+ flow_b = resize_flow(flow_b, self.h, self.w)
+ flows_f.append(flow_f)
+ flows_b.append(flow_b)
+
+ if len(frames) == self.num_local_frames: # random reverse
+ if random.random() < 0.5:
+ frames.reverse()
+ masks.reverse()
+ if self.load_flow:
+ flows_f.reverse()
+ flows_b.reverse()
+ flows_ = flows_f
+ flows_f = flows_b
+ flows_b = flows_
+
+ if self.load_flow:
+ frames, flows_f, flows_b = GroupRandomHorizontalFlowFlip()(frames, flows_f, flows_b)
+ else:
+ frames = GroupRandomHorizontalFlip()(frames)
+
+ # normalize and convert to tensors
+ frame_tensors = self._to_tensors(frames) * 2.0 - 1.0
+ mask_tensors = self._to_tensors(masks)
+ if self.load_flow:
+ flows_f = np.stack(flows_f, axis=-1) # H W 2 T-1
+ flows_b = np.stack(flows_b, axis=-1)
+ flows_f = torch.from_numpy(flows_f).permute(3, 2, 0, 1).contiguous().float()
+ flows_b = torch.from_numpy(flows_b).permute(3, 2, 0, 1).contiguous().float()
+
+ # img [-1,1] mask [0,1]
+ if self.load_flow:
+ return frame_tensors, mask_tensors, flows_f, flows_b, video_name
+ else:
+ return frame_tensors, mask_tensors, 'None', 'None', video_name
+
+
+class TestDataset(torch.utils.data.Dataset):
+ def __init__(self, args):
+ self.args = args
+ self.size = self.w, self.h = args['size']
+
+ self.video_root = args['video_root']
+ self.mask_root = args['mask_root']
+ self.flow_root = args['flow_root']
+
+ self.load_flow = args['load_flow']
+ if self.load_flow:
+ assert os.path.exists(self.flow_root)
+ self.video_names = sorted(os.listdir(self.mask_root))
+
+ self.video_dict = {}
+ self.frame_dict = {}
+
+ for v in self.video_names:
+ frame_list = sorted(os.listdir(os.path.join(self.video_root, v)))
+ v_len = len(frame_list)
+ self.video_dict[v] = v_len
+ self.frame_dict[v] = frame_list
+
+ self._to_tensors = transforms.Compose([
+ Stack(),
+ ToTorchFormatTensor(),
+ ])
+ self.file_client = FileClient('disk')
+
+ def __len__(self):
+ return len(self.video_names)
+
+ def __getitem__(self, index):
+ video_name = self.video_names[index]
+ selected_index = list(range(self.video_dict[video_name]))
+
+ # read video frames
+ frames = []
+ masks = []
+ flows_f, flows_b = [], []
+ for idx in selected_index:
+ frame_list = self.frame_dict[video_name]
+ frame_path = os.path.join(self.video_root, video_name, frame_list[idx])
+
+ img_bytes = self.file_client.get(frame_path, 'input')
+ img = imfrombytes(img_bytes, float32=False)
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img = cv2.resize(img, self.size, interpolation=cv2.INTER_LINEAR)
+ img = Image.fromarray(img)
+
+ frames.append(img)
+
+ mask_path = os.path.join(self.mask_root, video_name, str(idx).zfill(5) + '.png')
+ mask = Image.open(mask_path).resize(self.size, Image.NEAREST).convert('L')
+
+ # original convention: 0 indicates missing; here, 1 indicates missing
+ mask = np.asarray(mask)
+ m = np.array(mask > 0).astype(np.uint8)
+
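+ # dilate the mask slightly so the corrupted region fully covers object boundaries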
+ m = cv2.dilate(m,
+ cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)),
+ iterations=4)
+ mask = Image.fromarray(m * 255)
+ masks.append(mask)
+
+ if len(frames) <= len(selected_index)-1 and self.load_flow:
+ current_n = frame_list[idx][:-4]
+ next_n = frame_list[idx+1][:-4]
+ flow_f_path = os.path.join(self.flow_root, video_name, f'{current_n}_{next_n}_f.flo')
+ flow_b_path = os.path.join(self.flow_root, video_name, f'{next_n}_{current_n}_b.flo')
+ flow_f = flowread(flow_f_path, quantize=False)
+ flow_b = flowread(flow_b_path, quantize=False)
+ flow_f = resize_flow(flow_f, self.h, self.w)
+ flow_b = resize_flow(flow_b, self.h, self.w)
+ flows_f.append(flow_f)
+ flows_b.append(flow_b)
+
+ # normalize and convert to tensors
+ frames_PIL = [np.array(f).astype(np.uint8) for f in frames]
+ frame_tensors = self._to_tensors(frames) * 2.0 - 1.0
+ mask_tensors = self._to_tensors(masks)
+
+ if self.load_flow:
+ flows_f = np.stack(flows_f, axis=-1) # H W 2 T-1
+ flows_b = np.stack(flows_b, axis=-1)
+ flows_f = torch.from_numpy(flows_f).permute(3, 2, 0, 1).contiguous().float()
+ flows_b = torch.from_numpy(flows_b).permute(3, 2, 0, 1).contiguous().float()
+
+ if self.load_flow:
+ return frame_tensors, mask_tensors, flows_f, flows_b, video_name, frames_PIL
+ else:
+ return frame_tensors, mask_tensors, 'None', 'None', video_name
\ No newline at end of file
diff --git a/propainter/core/dist.py b/propainter/core/dist.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e4e9e670a3b853fac345618d3557d648d813902
--- /dev/null
+++ b/propainter/core/dist.py
@@ -0,0 +1,47 @@
+import os
+import torch
+
+
+def get_world_size():
+ """Find OMPI world size without calling mpi functions
+ :rtype: int
+ """
+ if os.environ.get('PMI_SIZE') is not None:
+ return int(os.environ.get('PMI_SIZE') or 1)
+ elif os.environ.get('OMPI_COMM_WORLD_SIZE') is not None:
+ return int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1)
+ else:
+ return torch.cuda.device_count()
+
+
+def get_global_rank():
+ """Find OMPI world rank without calling mpi functions
+ :rtype: int
+ """
+ if os.environ.get('PMI_RANK') is not None:
+ return int(os.environ.get('PMI_RANK') or 0)
+ elif os.environ.get('OMPI_COMM_WORLD_RANK') is not None:
+ return int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0)
+ else:
+ return 0
+
+
+def get_local_rank():
+ """Find OMPI local rank without calling mpi functions
+ :rtype: int
+ """
+ if os.environ.get('MPI_LOCALRANKID') is not None:
+ return int(os.environ.get('MPI_LOCALRANKID') or 0)
+ elif os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') is not None:
+ return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') or 0)
+ else:
+ return 0
+
+
+def get_master_ip():
+ if os.environ.get('AZ_BATCH_MASTER_NODE') is not None:
+ return os.environ.get('AZ_BATCH_MASTER_NODE').split(':')[0]
+ elif os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') is not None:
+ return os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE')
+ else:
+ return "127.0.0.1"
diff --git a/propainter/core/loss.py b/propainter/core/loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1d94d0ce9433b66ce2dce7adb24acb16051e8da
--- /dev/null
+++ b/propainter/core/loss.py
@@ -0,0 +1,180 @@
+import torch
+import torch.nn as nn
+import lpips
+from model.vgg_arch import VGGFeatureExtractor
+
+class PerceptualLoss(nn.Module):
+ """Perceptual loss with commonly used style loss.
+
+ Args:
+ layer_weights (dict): The weight for each layer of vgg feature.
+ Here is an example: {'conv5_4': 1.}, which means the conv5_4
+ feature layer (before relu5_4) will be extracted with weight
+ 1.0 in calculating losses.
+ vgg_type (str): The type of vgg network used as feature extractor.
+ Default: 'vgg19'.
+ use_input_norm (bool): If True, normalize the input image in vgg.
+ Default: True.
+ range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
+ Default: False.
+ perceptual_weight (float): If `perceptual_weight > 0`, the perceptual
+ loss will be calculated and the loss will be multiplied by the
+ weight. Default: 1.0.
+ style_weight (float): If `style_weight > 0`, the style loss will be
+ calculated and the loss will be multiplied by the weight.
+ Default: 0.
+ criterion (str): Criterion used for perceptual loss. Default: 'l1'.
+ """
+
+ def __init__(self,
+ layer_weights,
+ vgg_type='vgg19',
+ use_input_norm=True,
+ range_norm=False,
+ perceptual_weight=1.0,
+ style_weight=0.,
+ criterion='l1'):
+ super(PerceptualLoss, self).__init__()
+ self.perceptual_weight = perceptual_weight
+ self.style_weight = style_weight
+ self.layer_weights = layer_weights
+ self.vgg = VGGFeatureExtractor(
+ layer_name_list=list(layer_weights.keys()),
+ vgg_type=vgg_type,
+ use_input_norm=use_input_norm,
+ range_norm=range_norm)
+
+ self.criterion_type = criterion
+ if self.criterion_type == 'l1':
+ self.criterion = torch.nn.L1Loss()
+ elif self.criterion_type == 'l2':
+ self.criterion = torch.nn.MSELoss()
+ elif self.criterion_type == 'mse':
+ self.criterion = torch.nn.MSELoss(reduction='mean')
+ elif self.criterion_type == 'fro':
+ self.criterion = None
+ else:
+ raise NotImplementedError(f'{criterion} criterion has not been supported.')
+
+ def forward(self, x, gt):
+ """Forward function.
+
+ Args:
+ x (Tensor): Input tensor with shape (n, c, h, w).
+ gt (Tensor): Ground-truth tensor with shape (n, c, h, w).
+
+ Returns:
+ Tensor: Forward results.
+ """
+ # extract vgg features
+ x_features = self.vgg(x)
+ gt_features = self.vgg(gt.detach())
+
+ # calculate perceptual loss
+ if self.perceptual_weight > 0:
+ percep_loss = 0
+ for k in x_features.keys():
+ if self.criterion_type == 'fro':
+ percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k]
+ else:
+ percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k]
+ percep_loss *= self.perceptual_weight
+ else:
+ percep_loss = None
+
+ # calculate style loss
+ if self.style_weight > 0:
+ style_loss = 0
+ for k in x_features.keys():
+ if self.criterion_type == 'fro':
+ style_loss += torch.norm(
+ self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
+ else:
+ style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
+ gt_features[k])) * self.layer_weights[k]
+ style_loss *= self.style_weight
+ else:
+ style_loss = None
+
+ return percep_loss, style_loss
+
+ def _gram_mat(self, x):
+ """Calculate Gram matrix.
+
+ Args:
+ x (torch.Tensor): Tensor with shape of (n, c, h, w).
+
+ Returns:
+ torch.Tensor: Gram matrix.
+ """
+ n, c, h, w = x.size()
+ features = x.view(n, c, w * h)
+ features_t = features.transpose(1, 2)
+ gram = features.bmm(features_t) / (c * h * w)
+ return gram
+
+class LPIPSLoss(nn.Module):
+ def __init__(self,
+ loss_weight=1.0,
+ use_input_norm=True,
+ range_norm=False,):
+ super(LPIPSLoss, self).__init__()
+ self.perceptual = lpips.LPIPS(net="vgg", spatial=False).eval()
+ self.loss_weight = loss_weight
+ self.use_input_norm = use_input_norm
+ self.range_norm = range_norm
+
+ if self.use_input_norm:
+ # the mean is for image with range [0, 1]
+ self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+ # the std is for image with range [0, 1]
+ self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+ def forward(self, pred, target):
+ if self.range_norm:
+ pred = (pred + 1) / 2
+ target = (target + 1) / 2
+ if self.use_input_norm:
+ pred = (pred - self.mean) / self.std
+ target = (target - self.mean) / self.std
+ lpips_loss = self.perceptual(target.contiguous(), pred.contiguous())
+ return self.loss_weight * lpips_loss.mean(), None
+
+
+class AdversarialLoss(nn.Module):
+ r"""
+ Adversarial loss
+ https://arxiv.org/abs/1711.10337
+ """
+ def __init__(self,
+ type='nsgan',
+ target_real_label=1.0,
+ target_fake_label=0.0):
+ r"""
+ type = nsgan | lsgan | hinge
+ """
+ super(AdversarialLoss, self).__init__()
+ self.type = type
+ self.register_buffer('real_label', torch.tensor(target_real_label))
+ self.register_buffer('fake_label', torch.tensor(target_fake_label))
+
+ if type == 'nsgan':
+ self.criterion = nn.BCELoss()
+ elif type == 'lsgan':
+ self.criterion = nn.MSELoss()
+ elif type == 'hinge':
+ self.criterion = nn.ReLU()
+
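+ # hinge: the discriminator minimizes relu(1 - D(real)) + relu(1 + D(fake)); the generator minimizes -D(fake)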
+ def __call__(self, outputs, is_real, is_disc=None):
+ if self.type == 'hinge':
+ if is_disc:
+ if is_real:
+ outputs = -outputs
+ return self.criterion(1 + outputs).mean()
+ else:
+ return (-outputs).mean()
+ else:
+ labels = (self.real_label
+ if is_real else self.fake_label).expand_as(outputs)
+ loss = self.criterion(outputs, labels)
+ return loss
diff --git a/propainter/core/lr_scheduler.py b/propainter/core/lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bd1341cdcc64aa1c2a416b837551590ded4a43d
--- /dev/null
+++ b/propainter/core/lr_scheduler.py
@@ -0,0 +1,112 @@
+"""
+ LR scheduler from BasicSR https://github.com/xinntao/BasicSR
+"""
+import math
+from collections import Counter
+from torch.optim.lr_scheduler import _LRScheduler
+
+
+class MultiStepRestartLR(_LRScheduler):
+ """ MultiStep with restarts learning rate scheme.
+ Args:
+ optimizer (torch.nn.optimizer): Torch optimizer.
+ milestones (list): Iterations that will decrease learning rate.
+ gamma (float): Decrease ratio. Default: 0.1.
+ restarts (list): Restart iterations. Default: [0].
+ restart_weights (list): Restart weights at each restart iteration.
+ Default: [1].
+ last_epoch (int): Used in _LRScheduler. Default: -1.
+ """
+ def __init__(self,
+ optimizer,
+ milestones,
+ gamma=0.1,
+ restarts=(0, ),
+ restart_weights=(1, ),
+ last_epoch=-1):
+ self.milestones = Counter(milestones)
+ self.gamma = gamma
+ self.restarts = restarts
+ self.restart_weights = restart_weights
+ assert len(self.restarts) == len(
+ self.restart_weights), 'restarts and their weights do not match.'
+ super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ if self.last_epoch in self.restarts:
+ weight = self.restart_weights[self.restarts.index(self.last_epoch)]
+ return [
+ group['initial_lr'] * weight
+ for group in self.optimizer.param_groups
+ ]
+ if self.last_epoch not in self.milestones:
+ return [group['lr'] for group in self.optimizer.param_groups]
+ return [
+ group['lr'] * self.gamma**self.milestones[self.last_epoch]
+ for group in self.optimizer.param_groups
+ ]
+
+
+def get_position_from_periods(iteration, cumulative_period):
+ """Get the position from a period list.
+ It will return the index of the right-closest number in the period list.
+ For example, the cumulative_period = [100, 200, 300, 400],
+ if iteration == 50, return 0;
+ if iteration == 210, return 2;
+ if iteration == 300, return 2.
+ Args:
+ iteration (int): Current iteration.
+ cumulative_period (list[int]): Cumulative period list.
+ Returns:
+ int: The position of the right-closest number in the period list.
+ """
+ for i, period in enumerate(cumulative_period):
+ if iteration <= period:
+ return i
+
+
+class CosineAnnealingRestartLR(_LRScheduler):
+ """ Cosine annealing with restarts learning rate scheme.
+ An example of config:
+ periods = [10, 10, 10, 10]
+ restart_weights = [1, 0.5, 0.5, 0.5]
+ eta_min=1e-7
+ It has four cycles, each with 10 iterations. At the 10th, 20th and 30th
+ iterations, the scheduler will restart with the weights in restart_weights.
+ Args:
+ optimizer (torch.nn.optimizer): Torch optimizer.
+ periods (list): Period for each cosine annealing cycle.
+ restart_weights (list): Restart weights at each restart iteration.
+ Default: [1].
+ eta_min (float): The minimum lr. Default: 0.
+ last_epoch (int): Used in _LRScheduler. Default: -1.
+ """
+ def __init__(self,
+ optimizer,
+ periods,
+ restart_weights=(1, ),
+ eta_min=1e-7,
+ last_epoch=-1):
+ self.periods = periods
+ self.restart_weights = restart_weights
+ self.eta_min = eta_min
+ assert (len(self.periods) == len(self.restart_weights)
+ ), 'periods and restart_weights should have the same length.'
+ self.cumulative_period = [
+ sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
+ ]
+ super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
+
+ def get_lr(self):
+ idx = get_position_from_periods(self.last_epoch,
+ self.cumulative_period)
+ current_weight = self.restart_weights[idx]
+ nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
+ current_period = self.periods[idx]
+
+ return [
+ self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
+ (1 + math.cos(math.pi * (
+ (self.last_epoch - nearest_restart) / current_period)))
+ for base_lr in self.base_lrs
+ ]
diff --git a/propainter/core/metrics.py b/propainter/core/metrics.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5695c77249936ee1443c5da6ff90b14ee439a00
--- /dev/null
+++ b/propainter/core/metrics.py
@@ -0,0 +1,571 @@
+import numpy as np
+# from skimage import measure
+from skimage.metrics import structural_similarity as compare_ssim
+from scipy import linalg
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from propainter.core.utils import to_tensors
+
+
+def calculate_epe(flow1, flow2):
+ """Calculate End point errors."""
+
+ epe = torch.sum((flow1 - flow2)**2, dim=1).sqrt()
+ epe = epe.view(-1)
+ return epe.mean().item()
+
+
+def calculate_psnr(img1, img2):
+ """Calculate PSNR (Peak Signal-to-Noise Ratio).
+ Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
+ Args:
+ img1 (ndarray): Images with range [0, 255].
+ img2 (ndarray): Images with range [0, 255].
+ Returns:
+ float: psnr result.
+ """
+
+ assert img1.shape == img2.shape, \
+ (f'Image shapes are different: {img1.shape}, {img2.shape}.')
+
+ mse = np.mean((img1 - img2)**2)
+ if mse == 0:
+ return float('inf')
+ return 20. * np.log10(255. / np.sqrt(mse))
+
+
+def calc_psnr_and_ssim(img1, img2):
+ """Calculate PSNR and SSIM for images.
+ img1: ndarray, range [0, 255]
+ img2: ndarray, range [0, 255]
+ """
+ img1 = img1.astype(np.float64)
+ img2 = img2.astype(np.float64)
+
+ psnr = calculate_psnr(img1, img2)
+ ssim = compare_ssim(img1,
+ img2,
+ data_range=255,
+ multichannel=True,
+ win_size=65,
+ channel_axis=2)
+
+ return psnr, ssim
+
+
+###########################
+# I3D models
+###########################
+
+
+def init_i3d_model(i3d_model_path):
+ print(f"[Loading I3D model from {i3d_model_path} for FID score ..]")
+ i3d_model = InceptionI3d(400, in_channels=3, final_endpoint='Logits')
+ i3d_model.load_state_dict(torch.load(i3d_model_path))
+ i3d_model.to(torch.device('cuda:0'))
+ return i3d_model
+
+
+def calculate_i3d_activations(video1, video2, i3d_model, device):
+ """Calculate VFID metric.
+ video1: list[PIL.Image]
+ video2: list[PIL.Image]
+ """
+ video1 = to_tensors()(video1).unsqueeze(0).to(device)
+ video2 = to_tensors()(video2).unsqueeze(0).to(device)
+ video1_activations = get_i3d_activations(
+ video1, i3d_model).cpu().numpy().flatten()
+ video2_activations = get_i3d_activations(
+ video2, i3d_model).cpu().numpy().flatten()
+
+ return video1_activations, video2_activations
+
+
+def calculate_vfid(real_activations, fake_activations):
+ """
+ Given two distributions of features, compute the FID score between them.
+ Params:
+ real_activations: list[ndarray]
+ fake_activations: list[ndarray]
+ """
+ m1 = np.mean(real_activations, axis=0)
+ m2 = np.mean(fake_activations, axis=0)
+ s1 = np.cov(real_activations, rowvar=False)
+ s2 = np.cov(fake_activations, rowvar=False)
+ return calculate_frechet_distance(m1, s1, m2, s2)
+
+
+def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
+ """Numpy implementation of the Frechet Distance.
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
+ and X_2 ~ N(mu_2, C_2) is
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
+ Stable version by Dougal J. Sutherland.
+ Params:
+ -- mu1 : Numpy array containing the activations of a layer of the
+ inception net (like returned by the function 'get_predictions')
+ for generated samples.
+ -- mu2 : The sample mean over activations, precalculated on a
+ representative data set.
+ -- sigma1: The covariance matrix over activations for generated samples.
+ -- sigma2: The covariance matrix over activations, precalculated on a
+ representative data set.
+ Returns:
+ -- : The Frechet Distance.
+ """
+
+ mu1 = np.atleast_1d(mu1)
+ mu2 = np.atleast_1d(mu2)
+
+ sigma1 = np.atleast_2d(sigma1)
+ sigma2 = np.atleast_2d(sigma2)
+
+ assert mu1.shape == mu2.shape, \
+ 'Training and test mean vectors have different lengths'
+ assert sigma1.shape == sigma2.shape, \
+ 'Training and test covariances have different dimensions'
+
+ diff = mu1 - mu2
+
+ # Product might be almost singular
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
+ if not np.isfinite(covmean).all():
+ msg = ('fid calculation produces singular product; '
+ 'adding %s to diagonal of cov estimates') % eps
+ print(msg)
+ offset = np.eye(sigma1.shape[0]) * eps
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
+
+ # Numerical error might give slight imaginary component
+ if np.iscomplexobj(covmean):
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
+ m = np.max(np.abs(covmean.imag))
+ raise ValueError('Imaginary component {}'.format(m))
+ covmean = covmean.real
+
+ tr_covmean = np.trace(covmean)
+
+ return (diff.dot(diff) + np.trace(sigma1) + # NOQA
+ np.trace(sigma2) - 2 * tr_covmean)
+
+
+def get_i3d_activations(batched_video,
+ i3d_model,
+ target_endpoint='Logits',
+ flatten=True,
+ grad_enabled=False):
+ """
+ Get features from the I3D model and flatten them to a 1-d feature vector.
+ Valid target endpoints are defined in InceptionI3d.VALID_ENDPOINTS:
+ VALID_ENDPOINTS = (
+ 'Conv3d_1a_7x7',
+ 'MaxPool3d_2a_3x3',
+ 'Conv3d_2b_1x1',
+ 'Conv3d_2c_3x3',
+ 'MaxPool3d_3a_3x3',
+ 'Mixed_3b',
+ 'Mixed_3c',
+ 'MaxPool3d_4a_3x3',
+ 'Mixed_4b',
+ 'Mixed_4c',
+ 'Mixed_4d',
+ 'Mixed_4e',
+ 'Mixed_4f',
+ 'MaxPool3d_5a_2x2',
+ 'Mixed_5b',
+ 'Mixed_5c',
+ 'Logits',
+ 'Predictions',
+ )
+ """
+ with torch.set_grad_enabled(grad_enabled):
+ feat = i3d_model.extract_features(batched_video.transpose(1, 2),
+ target_endpoint)
+ if flatten:
+ feat = feat.view(feat.size(0), -1)
+
+ return feat
+
+
+# This code is from https://github.com/piergiaj/pytorch-i3d/blob/master/pytorch_i3d.py
+# I only fix flake8 errors and do some cleaning here
+
+
+class MaxPool3dSamePadding(nn.MaxPool3d):
+ def compute_pad(self, dim, s):
+ if s % self.stride[dim] == 0:
+ return max(self.kernel_size[dim] - self.stride[dim], 0)
+ else:
+ return max(self.kernel_size[dim] - (s % self.stride[dim]), 0)
+
+ def forward(self, x):
+ # compute 'same' padding
+ (batch, channel, t, h, w) = x.size()
+ pad_t = self.compute_pad(0, t)
+ pad_h = self.compute_pad(1, h)
+ pad_w = self.compute_pad(2, w)
+
+ pad_t_f = pad_t // 2
+ pad_t_b = pad_t - pad_t_f
+ pad_h_f = pad_h // 2
+ pad_h_b = pad_h - pad_h_f
+ pad_w_f = pad_w // 2
+ pad_w_b = pad_w - pad_w_f
+
+ pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
+ x = F.pad(x, pad)
+ return super(MaxPool3dSamePadding, self).forward(x)
+
+
+class Unit3D(nn.Module):
+ def __init__(self,
+ in_channels,
+ output_channels,
+ kernel_shape=(1, 1, 1),
+ stride=(1, 1, 1),
+ padding=0,
+ activation_fn=F.relu,
+ use_batch_norm=True,
+ use_bias=False,
+ name='unit_3d'):
+ """Initializes Unit3D module."""
+ super(Unit3D, self).__init__()
+
+ self._output_channels = output_channels
+ self._kernel_shape = kernel_shape
+ self._stride = stride
+ self._use_batch_norm = use_batch_norm
+ self._activation_fn = activation_fn
+ self._use_bias = use_bias
+ self.name = name
+ self.padding = padding
+
+ self.conv3d = nn.Conv3d(
+ in_channels=in_channels,
+ out_channels=self._output_channels,
+ kernel_size=self._kernel_shape,
+ stride=self._stride,
+ padding=0, # we always want padding to be 0 here. We will
+ # dynamically pad based on input size in forward function
+ bias=self._use_bias)
+
+ if self._use_batch_norm:
+ self.bn = nn.BatchNorm3d(self._output_channels,
+ eps=0.001,
+ momentum=0.01)
+
+ def compute_pad(self, dim, s):
+ if s % self._stride[dim] == 0:
+ return max(self._kernel_shape[dim] - self._stride[dim], 0)
+ else:
+ return max(self._kernel_shape[dim] - (s % self._stride[dim]), 0)
+
+ def forward(self, x):
+ # compute 'same' padding
+ (batch, channel, t, h, w) = x.size()
+ pad_t = self.compute_pad(0, t)
+ pad_h = self.compute_pad(1, h)
+ pad_w = self.compute_pad(2, w)
+
+ pad_t_f = pad_t // 2
+ pad_t_b = pad_t - pad_t_f
+ pad_h_f = pad_h // 2
+ pad_h_b = pad_h - pad_h_f
+ pad_w_f = pad_w // 2
+ pad_w_b = pad_w - pad_w_f
+
+ pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
+ x = F.pad(x, pad)
+
+ x = self.conv3d(x)
+ if self._use_batch_norm:
+ x = self.bn(x)
+ if self._activation_fn is not None:
+ x = self._activation_fn(x)
+ return x
+
+
+class InceptionModule(nn.Module):
+ def __init__(self, in_channels, out_channels, name):
+ super(InceptionModule, self).__init__()
+
+ self.b0 = Unit3D(in_channels=in_channels,
+ output_channels=out_channels[0],
+ kernel_shape=[1, 1, 1],
+ padding=0,
+ name=name + '/Branch_0/Conv3d_0a_1x1')
+ self.b1a = Unit3D(in_channels=in_channels,
+ output_channels=out_channels[1],
+ kernel_shape=[1, 1, 1],
+ padding=0,
+ name=name + '/Branch_1/Conv3d_0a_1x1')
+ self.b1b = Unit3D(in_channels=out_channels[1],
+ output_channels=out_channels[2],
+ kernel_shape=[3, 3, 3],
+ name=name + '/Branch_1/Conv3d_0b_3x3')
+ self.b2a = Unit3D(in_channels=in_channels,
+ output_channels=out_channels[3],
+ kernel_shape=[1, 1, 1],
+ padding=0,
+ name=name + '/Branch_2/Conv3d_0a_1x1')
+ self.b2b = Unit3D(in_channels=out_channels[3],
+ output_channels=out_channels[4],
+ kernel_shape=[3, 3, 3],
+ name=name + '/Branch_2/Conv3d_0b_3x3')
+ self.b3a = MaxPool3dSamePadding(kernel_size=[3, 3, 3],
+ stride=(1, 1, 1),
+ padding=0)
+ self.b3b = Unit3D(in_channels=in_channels,
+ output_channels=out_channels[5],
+ kernel_shape=[1, 1, 1],
+ padding=0,
+ name=name + '/Branch_3/Conv3d_0b_1x1')
+ self.name = name
+
+ def forward(self, x):
+ b0 = self.b0(x)
+ b1 = self.b1b(self.b1a(x))
+ b2 = self.b2b(self.b2a(x))
+ b3 = self.b3b(self.b3a(x))
+ return torch.cat([b0, b1, b2, b3], dim=1)
+
+
+class InceptionI3d(nn.Module):
+ """Inception-v1 I3D architecture.
+ The model is introduced in:
+ Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset
+ Joao Carreira, Andrew Zisserman
+ https://arxiv.org/pdf/1705.07750v1.pdf.
+ See also the Inception architecture, introduced in:
+ Going deeper with convolutions
+ Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,
+ Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich.
+ http://arxiv.org/pdf/1409.4842v1.pdf.
+ """
+
+ # Endpoints of the model in order. During construction, all the endpoints up
+ # to a designated `final_endpoint` are returned in a dictionary as the
+ # second return value.
+ VALID_ENDPOINTS = (
+ 'Conv3d_1a_7x7',
+ 'MaxPool3d_2a_3x3',
+ 'Conv3d_2b_1x1',
+ 'Conv3d_2c_3x3',
+ 'MaxPool3d_3a_3x3',
+ 'Mixed_3b',
+ 'Mixed_3c',
+ 'MaxPool3d_4a_3x3',
+ 'Mixed_4b',
+ 'Mixed_4c',
+ 'Mixed_4d',
+ 'Mixed_4e',
+ 'Mixed_4f',
+ 'MaxPool3d_5a_2x2',
+ 'Mixed_5b',
+ 'Mixed_5c',
+ 'Logits',
+ 'Predictions',
+ )
+
+ def __init__(self,
+ num_classes=400,
+ spatial_squeeze=True,
+ final_endpoint='Logits',
+ name='inception_i3d',
+ in_channels=3,
+ dropout_keep_prob=0.5):
+ """Initializes I3D model instance.
+ Args:
+ num_classes: The number of outputs in the logit layer (default 400, which
+ matches the Kinetics dataset).
+ spatial_squeeze: Whether to squeeze the spatial dimensions for the logits
+ before returning (default True).
+ final_endpoint: The model contains many possible endpoints.
+ `final_endpoint` specifies the last endpoint for the model to be built
+ up to. In addition to the output at `final_endpoint`, all the outputs
+ at endpoints up to `final_endpoint` will also be returned, in a
+ dictionary. `final_endpoint` must be one of
+ InceptionI3d.VALID_ENDPOINTS (default 'Logits').
+ name: A string (optional). The name of this module.
+ Raises:
+ ValueError: if `final_endpoint` is not recognized.
+ """
+
+ if final_endpoint not in self.VALID_ENDPOINTS:
+ raise ValueError('Unknown final endpoint %s' % final_endpoint)
+
+ super(InceptionI3d, self).__init__()
+ self._num_classes = num_classes
+ self._spatial_squeeze = spatial_squeeze
+ self._final_endpoint = final_endpoint
+ self.logits = None
+
+ if self._final_endpoint not in self.VALID_ENDPOINTS:
+ raise ValueError('Unknown final endpoint %s' %
+ self._final_endpoint)
+
+ self.end_points = {}
+ end_point = 'Conv3d_1a_7x7'
+ self.end_points[end_point] = Unit3D(in_channels=in_channels,
+ output_channels=64,
+ kernel_shape=[7, 7, 7],
+ stride=(2, 2, 2),
+ padding=(3, 3, 3),
+ name=name + end_point)
+ if self._final_endpoint == end_point:
+ return
+
+ end_point = 'MaxPool3d_2a_3x3'
+ self.end_points[end_point] = MaxPool3dSamePadding(
+ kernel_size=[1, 3, 3], stride=(1, 2, 2), padding=0)
+ if self._final_endpoint == end_point:
+ return
+
+ end_point = 'Conv3d_2b_1x1'
+ self.end_points[end_point] = Unit3D(in_channels=64,
+ output_channels=64,
+ kernel_shape=[1, 1, 1],
+ padding=0,
+ name=name + end_point)
+ if self._final_endpoint == end_point:
+ return
+
+ end_point = 'Conv3d_2c_3x3'
+ self.end_points[end_point] = Unit3D(in_channels=64,
+ output_channels=192,
+ kernel_shape=[3, 3, 3],
+ padding=1,
+ name=name + end_point)
+ if self._final_endpoint == end_point:
+ return
+
+ end_point = 'MaxPool3d_3a_3x3'
+ self.end_points[end_point] = MaxPool3dSamePadding(
+ kernel_size=[1, 3, 3], stride=(1, 2, 2), padding=0)
+ if self._final_endpoint == end_point:
+ return
+
+ end_point = 'Mixed_3b'
+ self.end_points[end_point] = InceptionModule(192,
+ [64, 96, 128, 16, 32, 32],
+ name + end_point)
+ if self._final_endpoint == end_point:
+ return
+
+ end_point = 'Mixed_3c'
+ self.end_points[end_point] = InceptionModule(
+ 256, [128, 128, 192, 32, 96, 64], name + end_point)
+ if self._final_endpoint == end_point:
+ return
+
+ end_point = 'MaxPool3d_4a_3x3'
+ self.end_points[end_point] = MaxPool3dSamePadding(
+ kernel_size=[3, 3, 3], stride=(2, 2, 2), padding=0)
+ if self._final_endpoint == end_point:
+ return
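+
+ # The input channel count of each Inception module below is the sum of the
+ # four branch output widths of the preceding module (e.g. 128+192+96+64 = 480
+ # for Mixed_4b, which follows Mixed_3c).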
+
+ end_point = 'Mixed_4b'
+ self.end_points[end_point] = InceptionModule(
+ 128 + 192 + 96 + 64, [192, 96, 208, 16, 48, 64], name + end_point)
+ if self._final_endpoint == end_point:
+ return
+
+ end_point = 'Mixed_4c'
+ self.end_points[end_point] = InceptionModule(
+ 192 + 208 + 48 + 64, [160, 112, 224, 24, 64, 64], name + end_point)
+ if self._final_endpoint == end_point:
+ return
+
+ end_point = 'Mixed_4d'
+ self.end_points[end_point] = InceptionModule(
+ 160 + 224 + 64 + 64, [128, 128, 256, 24, 64, 64], name + end_point)
+ if self._final_endpoint == end_point:
+ return
+
+ end_point = 'Mixed_4e'
+ self.end_points[end_point] = InceptionModule(
+ 128 + 256 + 64 + 64, [112, 144, 288, 32, 64, 64], name + end_point)
+ if self._final_endpoint == end_point:
+ return
+
+ end_point = 'Mixed_4f'
+ self.end_points[end_point] = InceptionModule(
+ 112 + 288 + 64 + 64, [256, 160, 320, 32, 128, 128],
+ name + end_point)
+ if self._final_endpoint == end_point:
+ return
+
+ end_point = 'MaxPool3d_5a_2x2'
+ self.end_points[end_point] = MaxPool3dSamePadding(
+ kernel_size=[2, 2, 2], stride=(2, 2, 2), padding=0)
+ if self._final_endpoint == end_point:
+ return
+
+ end_point = 'Mixed_5b'
+ self.end_points[end_point] = InceptionModule(
+ 256 + 320 + 128 + 128, [256, 160, 320, 32, 128, 128],
+ name + end_point)
+ if self._final_endpoint == end_point:
+ return
+
+ end_point = 'Mixed_5c'
+ self.end_points[end_point] = InceptionModule(
+ 256 + 320 + 128 + 128, [384, 192, 384, 48, 128, 128],
+ name + end_point)
+ if self._final_endpoint == end_point:
+ return
+
+ end_point = 'Logits'
+ self.avg_pool = nn.AvgPool3d(kernel_size=[2, 7, 7], stride=(1, 1, 1))
+ self.dropout = nn.Dropout(dropout_keep_prob)
+ self.logits = Unit3D(in_channels=384 + 384 + 128 + 128,
+ output_channels=self._num_classes,
+ kernel_shape=[1, 1, 1],
+ padding=0,
+ activation_fn=None,
+ use_batch_norm=False,
+ use_bias=True,
+ name='logits')
+
+ self.build()
+
+ def replace_logits(self, num_classes):
+ self._num_classes = num_classes
+ self.logits = Unit3D(in_channels=384 + 384 + 128 + 128,
+ output_channels=self._num_classes,
+ kernel_shape=[1, 1, 1],
+ padding=0,
+ activation_fn=None,
+ use_batch_norm=False,
+ use_bias=True,
+ name='logits')
+
+ def build(self):
+ for k in self.end_points.keys():
+ self.add_module(k, self.end_points[k])
+
+ def forward(self, x):
+ for end_point in self.VALID_ENDPOINTS:
+ if end_point in self.end_points:
+ x = self._modules[end_point](
+ x) # use _modules to work with dataparallel
+
+ x = self.logits(self.dropout(self.avg_pool(x)))
+ if self._spatial_squeeze:
+ x = x.squeeze(3).squeeze(3)
+ # x is batch x classes x time, which is what we want to work with
+ return x
+
+ def extract_features(self, x, target_endpoint='Logits'):
+ for end_point in self.VALID_ENDPOINTS:
+ if end_point in self.end_points:
+ x = self._modules[end_point](x)
+ if end_point == target_endpoint:
+ break
+ if target_endpoint == 'Logits':
+ return x.mean(4).mean(3).mean(2)
+ else:
+ return x
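+
+
+# Usage sketch (illustrative only; the shapes below are just an example):
+# build the network and extract clip-level features from an input of shape
+# (batch, channels, time, height, width).
+#
+#   model = InceptionI3d(num_classes=400, in_channels=3)
+#   model.eval()
+#   clip = torch.rand(1, 3, 16, 224, 224)
+#   feats = model.extract_features(clip)  # -> (1, 1024) pooled features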
diff --git a/propainter/core/prefetch_dataloader.py b/propainter/core/prefetch_dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..5088425050d4cc98114a9b93eb50ea60273f35a0
--- /dev/null
+++ b/propainter/core/prefetch_dataloader.py
@@ -0,0 +1,125 @@
+import queue as Queue
+import threading
+import torch
+from torch.utils.data import DataLoader
+
+
+class PrefetchGenerator(threading.Thread):
+ """A general prefetch generator.
+
+ Ref:
+ https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
+
+ Args:
+ generator: Python generator.
+ num_prefetch_queue (int): Size of the prefetch queue.
+ """
+
+ def __init__(self, generator, num_prefetch_queue):
+ threading.Thread.__init__(self)
+ self.queue = Queue.Queue(num_prefetch_queue)
+ self.generator = generator
+ self.daemon = True
+ self.start()
+
+ def run(self):
+ for item in self.generator:
+ self.queue.put(item)
+ self.queue.put(None)
+
+ def __next__(self):
+ next_item = self.queue.get()
+ if next_item is None:
+ raise StopIteration
+ return next_item
+
+ def __iter__(self):
+ return self
+
+
+class PrefetchDataLoader(DataLoader):
+ """Prefetch version of dataloader.
+
+ Ref:
+ https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
+
+ TODO:
+ Need to test on single gpu and ddp (multi-gpu). There is a known issue in
+ ddp.
+
+ Args:
+ num_prefetch_queue (int): Size of the prefetch queue.
+ kwargs (dict): Other arguments for dataloader.
+ """
+
+ def __init__(self, num_prefetch_queue, **kwargs):
+ self.num_prefetch_queue = num_prefetch_queue
+ super(PrefetchDataLoader, self).__init__(**kwargs)
+
+ def __iter__(self):
+ return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
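+
+# Usage sketch (`my_dataset` is a placeholder for any torch Dataset; the
+# remaining kwargs are forwarded to torch.utils.data.DataLoader):
+#
+#   loader = PrefetchDataLoader(num_prefetch_queue=2, dataset=my_dataset,
+#                               batch_size=4, num_workers=2, shuffle=True)
+#   for batch in loader:  # batches are prepared one step ahead by a thread
+#       ...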
+
+
+class CPUPrefetcher():
+ """CPU prefetcher.
+
+ Args:
+ loader: Dataloader.
+ """
+
+ def __init__(self, loader):
+ self.ori_loader = loader
+ self.loader = iter(loader)
+
+ def next(self):
+ try:
+ return next(self.loader)
+ except StopIteration:
+ return None
+
+ def reset(self):
+ self.loader = iter(self.ori_loader)
+
+
+class CUDAPrefetcher():
+ """CUDA prefetcher.
+
+ Ref:
+ https://github.com/NVIDIA/apex/issues/304#
+
+ It may consume more GPU memory.
+
+ Args:
+ loader: Dataloader.
+ opt (dict): Options.
+ """
+
+ def __init__(self, loader, opt):
+ self.ori_loader = loader
+ self.loader = iter(loader)
+ self.opt = opt
+ self.stream = torch.cuda.Stream()
+ self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
+ self.preload()
+
+ def preload(self):
+ try:
+ self.batch = next(self.loader) # self.batch is a dict
+ except StopIteration:
+ self.batch = None
+ return None
+ # put tensors to gpu
+ with torch.cuda.stream(self.stream):
+ for k, v in self.batch.items():
+ if torch.is_tensor(v):
+ self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True)
+
+ def next(self):
+ torch.cuda.current_stream().wait_stream(self.stream)
+ batch = self.batch
+ self.preload()
+ return batch
+
+ def reset(self):
+ self.loader = iter(self.ori_loader)
+ self.preload()
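+
+
+# Usage sketch (`train_loader` is a placeholder DataLoader; `opt` only needs a
+# 'num_gpu' entry here):
+#
+#   prefetcher = CUDAPrefetcher(train_loader, opt={'num_gpu': 1})
+#   batch = prefetcher.next()
+#   while batch is not None:
+#       ...  # tensor values in `batch` are already on the GPU
+#       batch = prefetcher.next()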
diff --git a/propainter/core/trainer.py b/propainter/core/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e6b6a669e92131697fcb12fd548509ce1b81080
--- /dev/null
+++ b/propainter/core/trainer.py
@@ -0,0 +1,509 @@
+import os
+import glob
+import logging
+import importlib
+from tqdm import tqdm
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from core.prefetch_dataloader import PrefetchDataLoader, CPUPrefetcher
+from torch.utils.data.distributed import DistributedSampler
+from torch.nn.parallel import DistributedDataParallel as DDP
+import torchvision
+from torch.utils.tensorboard import SummaryWriter
+
+from core.lr_scheduler import MultiStepRestartLR, CosineAnnealingRestartLR
+from core.loss import AdversarialLoss, PerceptualLoss, LPIPSLoss
+from core.dataset import TrainDataset
+
+from model.modules.flow_comp_raft import RAFT_bi, FlowLoss, EdgeLoss
+from model.recurrent_flow_completion import RecurrentFlowCompleteNet
+
+from RAFT.utils.flow_viz_pt import flow_to_image
+
+
+class Trainer:
+ def __init__(self, config):
+ self.config = config
+ self.epoch = 0
+ self.iteration = 0
+ self.num_local_frames = config['train_data_loader']['num_local_frames']
+ self.num_ref_frames = config['train_data_loader']['num_ref_frames']
+
+ # setup data set and data loader
+ self.train_dataset = TrainDataset(config['train_data_loader'])
+
+ self.train_sampler = None
+ self.train_args = config['trainer']
+ if config['distributed']:
+ self.train_sampler = DistributedSampler(
+ self.train_dataset,
+ num_replicas=config['world_size'],
+ rank=config['global_rank'])
+
+ dataloader_args = dict(
+ dataset=self.train_dataset,
+ batch_size=self.train_args['batch_size'] // config['world_size'],
+ shuffle=(self.train_sampler is None),
+ num_workers=self.train_args['num_workers'],
+ sampler=self.train_sampler,
+ drop_last=True)
+
+ self.train_loader = PrefetchDataLoader(self.train_args['num_prefetch_queue'], **dataloader_args)
+ self.prefetcher = CPUPrefetcher(self.train_loader)
+
+ # set loss functions
+ self.adversarial_loss = AdversarialLoss(type=self.config['losses']['GAN_LOSS'])
+ self.adversarial_loss = self.adversarial_loss.to(self.config['device'])
+ self.l1_loss = nn.L1Loss()
+ # self.perc_loss = PerceptualLoss(
+ # layer_weights={'conv3_4': 0.25, 'conv4_4': 0.25, 'conv5_4': 0.5},
+ # use_input_norm=True,
+ # range_norm=True,
+ # criterion='l1'
+ # ).to(self.config['device'])
+
+ if self.config['losses']['perceptual_weight'] > 0:
+ self.perc_loss = LPIPSLoss(use_input_norm=True, range_norm=True).to(self.config['device'])
+
+ # self.flow_comp_loss = FlowCompletionLoss().to(self.config['device'])
+ # self.flow_comp_loss = FlowCompletionLoss(self.config['device'])
+
+ # set raft
+ self.fix_raft = RAFT_bi(device = self.config['device'])
+ self.fix_flow_complete = RecurrentFlowCompleteNet('weights/recurrent_flow_completion.pth')
+ for p in self.fix_flow_complete.parameters():
+ p.requires_grad = False
+ self.fix_flow_complete.to(self.config['device'])
+ self.fix_flow_complete.eval()
+
+ # self.flow_loss = FlowLoss()
+
+ # setup models including generator and discriminator
+ net = importlib.import_module('model.' + config['model']['net'])
+ self.netG = net.InpaintGenerator()
+ # print(self.netG)
+ self.netG = self.netG.to(self.config['device'])
+ if not self.config['model'].get('no_dis', False):
+ if self.config['model'].get('dis_2d', False):
+ self.netD = net.Discriminator_2D(
+ in_channels=3,
+ use_sigmoid=config['losses']['GAN_LOSS'] != 'hinge')
+ else:
+ self.netD = net.Discriminator(
+ in_channels=3,
+ use_sigmoid=config['losses']['GAN_LOSS'] != 'hinge')
+ self.netD = self.netD.to(self.config['device'])
+
+ self.interp_mode = self.config['model']['interp_mode']
+ # setup optimizers and schedulers
+ self.setup_optimizers()
+ self.setup_schedulers()
+ self.load()
+
+ if config['distributed']:
+ self.netG = DDP(self.netG,
+ device_ids=[self.config['local_rank']],
+ output_device=self.config['local_rank'],
+ broadcast_buffers=True,
+ find_unused_parameters=True)
+ if not self.config['model']['no_dis']:
+ self.netD = DDP(self.netD,
+ device_ids=[self.config['local_rank']],
+ output_device=self.config['local_rank'],
+ broadcast_buffers=True,
+ find_unused_parameters=False)
+
+ # set summary writer
+ self.dis_writer = None
+ self.gen_writer = None
+ self.summary = {}
+ if self.config['global_rank'] == 0 or (not config['distributed']):
+ if not self.config['model']['no_dis']:
+ self.dis_writer = SummaryWriter(
+ os.path.join(config['save_dir'], 'dis'))
+ self.gen_writer = SummaryWriter(
+ os.path.join(config['save_dir'], 'gen'))
+
+ def setup_optimizers(self):
+ """Set up optimizers."""
+ backbone_params = []
+ for name, param in self.netG.named_parameters():
+ if param.requires_grad:
+ backbone_params.append(param)
+ else:
+ print(f'Params {name} will not be optimized.')
+
+ optim_params = [
+ {
+ 'params': backbone_params,
+ 'lr': self.config['trainer']['lr']
+ },
+ ]
+
+ self.optimG = torch.optim.Adam(optim_params,
+ betas=(self.config['trainer']['beta1'],
+ self.config['trainer']['beta2']))
+
+ if not self.config['model']['no_dis']:
+ self.optimD = torch.optim.Adam(
+ self.netD.parameters(),
+ lr=self.config['trainer']['lr'],
+ betas=(self.config['trainer']['beta1'],
+ self.config['trainer']['beta2']))
+
+ def setup_schedulers(self):
+ """Set up schedulers."""
+ scheduler_opt = self.config['trainer']['scheduler']
+ scheduler_type = scheduler_opt.pop('type')
+
+ if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
+ self.scheG = MultiStepRestartLR(
+ self.optimG,
+ milestones=scheduler_opt['milestones'],
+ gamma=scheduler_opt['gamma'])
+ if not self.config['model']['no_dis']:
+ self.scheD = MultiStepRestartLR(
+ self.optimD,
+ milestones=scheduler_opt['milestones'],
+ gamma=scheduler_opt['gamma'])
+ elif scheduler_type == 'CosineAnnealingRestartLR':
+ self.scheG = CosineAnnealingRestartLR(
+ self.optimG,
+ periods=scheduler_opt['periods'],
+ restart_weights=scheduler_opt['restart_weights'],
+ eta_min=scheduler_opt['eta_min'])
+ if not self.config['model']['no_dis']:
+ self.scheD = CosineAnnealingRestartLR(
+ self.optimD,
+ periods=scheduler_opt['periods'],
+ restart_weights=scheduler_opt['restart_weights'],
+ eta_min=scheduler_opt['eta_min'])
+ else:
+ raise NotImplementedError(
+ f'Scheduler {scheduler_type} is not implemented yet.')
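+
+ # Example scheduler block in the trainer config (illustrative values only;
+ # the keys mirror what this method reads):
+ #   scheduler:
+ #     type: CosineAnnealingRestartLR
+ #     periods: [400000]
+ #     restart_weights: [1]
+ #     eta_min: 1.0e-7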
+
+ def update_learning_rate(self):
+ """Update learning rate."""
+ self.scheG.step()
+ if not self.config['model']['no_dis']:
+ self.scheD.step()
+
+ def get_lr(self):
+ """Get current learning rate."""
+ return self.optimG.param_groups[0]['lr']
+
+ def add_summary(self, writer, name, val):
+ """Add tensorboard summary."""
+ if name not in self.summary:
+ self.summary[name] = 0
+ self.summary[name] += val
+ n = self.train_args['log_freq']
+ if writer is not None and self.iteration % n == 0:
+ writer.add_scalar(name, self.summary[name] / n, self.iteration)
+ self.summary[name] = 0
+
+ def load(self):
+ """Load netG (and netD)."""
+ # get the latest checkpoint
+ model_path = self.config['save_dir']
+ # TODO: add resume name
+ if os.path.isfile(os.path.join(model_path, 'latest.ckpt')):
+ latest_epoch = open(os.path.join(model_path, 'latest.ckpt'),
+ 'r').read().splitlines()[-1]
+ else:
+ ckpts = [
+ os.path.basename(i).split('.pth')[0]
+ for i in glob.glob(os.path.join(model_path, '*.pth'))
+ ]
+ ckpts.sort()
+ latest_epoch = ckpts[-1][4:] if len(ckpts) > 0 else None
+
+ if latest_epoch is not None:
+ gen_path = os.path.join(model_path,
+ f'gen_{int(latest_epoch):06d}.pth')
+ dis_path = os.path.join(model_path,
+ f'dis_{int(latest_epoch):06d}.pth')
+ opt_path = os.path.join(model_path,
+ f'opt_{int(latest_epoch):06d}.pth')
+
+ if self.config['global_rank'] == 0:
+ print(f'Loading model from {gen_path}...')
+ dataG = torch.load(gen_path, map_location=self.config['device'])
+ self.netG.load_state_dict(dataG)
+ if not self.config['model']['no_dis'] and self.config['model']['load_d']:
+ dataD = torch.load(dis_path, map_location=self.config['device'])
+ self.netD.load_state_dict(dataD)
+
+ data_opt = torch.load(opt_path, map_location=self.config['device'])
+ self.optimG.load_state_dict(data_opt['optimG'])
+ # self.scheG.load_state_dict(data_opt['scheG'])
+ if not self.config['model']['no_dis'] and self.config['model']['load_d']:
+ self.optimD.load_state_dict(data_opt['optimD'])
+ # self.scheD.load_state_dict(data_opt['scheD'])
+ self.epoch = data_opt['epoch']
+ self.iteration = data_opt['iteration']
+ else:
+ gen_path = self.config['trainer'].get('gen_path', None)
+ dis_path = self.config['trainer'].get('dis_path', None)
+ opt_path = self.config['trainer'].get('opt_path', None)
+ if gen_path is not None:
+ if self.config['global_rank'] == 0:
+ print(f'Loading Gen-Net from {gen_path}...')
+ dataG = torch.load(gen_path, map_location=self.config['device'])
+ self.netG.load_state_dict(dataG)
+
+ if dis_path is not None and not self.config['model']['no_dis'] and self.config['model']['load_d']:
+ if self.config['global_rank'] == 0:
+ print(f'Loading Dis-Net from {dis_path}...')
+ dataD = torch.load(dis_path, map_location=self.config['device'])
+ self.netD.load_state_dict(dataD)
+ if opt_path is not None:
+ data_opt = torch.load(opt_path, map_location=self.config['device'])
+ self.optimG.load_state_dict(data_opt['optimG'])
+ self.scheG.load_state_dict(data_opt['scheG'])
+ if not self.config['model']['no_dis'] and self.config['model']['load_d']:
+ self.optimD.load_state_dict(data_opt['optimD'])
+ self.scheD.load_state_dict(data_opt['scheD'])
+ else:
+ if self.config['global_rank'] == 0:
+ print('Warning: no trained model found. '
+ 'An initialized model will be used.')
+
+ def save(self, it):
+ """Save parameters every eval_epoch"""
+ if self.config['global_rank'] == 0:
+ # configure path
+ gen_path = os.path.join(self.config['save_dir'],
+ f'gen_{it:06d}.pth')
+ dis_path = os.path.join(self.config['save_dir'],
+ f'dis_{it:06d}.pth')
+ opt_path = os.path.join(self.config['save_dir'],
+ f'opt_{it:06d}.pth')
+ print(f'\nsaving model to {gen_path} ...')
+
+ # remove .module for saving
+ if isinstance(self.netG, torch.nn.DataParallel) or isinstance(self.netG, DDP):
+ netG = self.netG.module
+ if not self.config['model']['no_dis']:
+ netD = self.netD.module
+ else:
+ netG = self.netG
+ if not self.config['model']['no_dis']:
+ netD = self.netD
+
+ # save checkpoints
+ torch.save(netG.state_dict(), gen_path)
+ if not self.config['model']['no_dis']:
+ torch.save(netD.state_dict(), dis_path)
+ torch.save(
+ {
+ 'epoch': self.epoch,
+ 'iteration': self.iteration,
+ 'optimG': self.optimG.state_dict(),
+ 'optimD': self.optimD.state_dict(),
+ 'scheG': self.scheG.state_dict(),
+ 'scheD': self.scheD.state_dict()
+ }, opt_path)
+ else:
+ torch.save(
+ {
+ 'epoch': self.epoch,
+ 'iteration': self.iteration,
+ 'optimG': self.optimG.state_dict(),
+ 'scheG': self.scheG.state_dict()
+ }, opt_path)
+
+ latest_path = os.path.join(self.config['save_dir'], 'latest.ckpt')
+ os.system(f"echo {it:06d} > {latest_path}")
+
+ def train(self):
+ """training entry"""
+ pbar = range(int(self.train_args['iterations']))
+ if self.config['global_rank'] == 0:
+ pbar = tqdm(pbar,
+ initial=self.iteration,
+ dynamic_ncols=True,
+ smoothing=0.01)
+
+ os.makedirs('logs', exist_ok=True)
+
+ logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s %(filename)s[line:%(lineno)d]"
+ "%(levelname)s %(message)s",
+ datefmt="%a, %d %b %Y %H:%M:%S",
+ filename=f"logs/{self.config['save_dir'].split('/')[-1]}.log",
+ filemode='w')
+
+ while True:
+ self.epoch += 1
+ self.prefetcher.reset()
+ if self.config['distributed']:
+ self.train_sampler.set_epoch(self.epoch)
+ self._train_epoch(pbar)
+ if self.iteration > self.train_args['iterations']:
+ break
+ print('\nEnd training....')
+
+ def _train_epoch(self, pbar):
+ """Process input and calculate loss every training epoch"""
+ device = self.config['device']
+ train_data = self.prefetcher.next()
+ while train_data is not None:
+ self.iteration += 1
+ frames, masks, flows_f, flows_b, _ = train_data
+ frames, masks = frames.to(device), masks.to(device).float()
+ l_t = self.num_local_frames
+ b, t, c, h, w = frames.size()
+ gt_local_frames = frames[:, :l_t, ...]
+ local_masks = masks[:, :l_t, ...].contiguous()
+
+ masked_frames = frames * (1 - masks)
+ masked_local_frames = masked_frames[:, :l_t, ...]
+ # get gt optical flow
+ if flows_f[0] == 'None' or flows_b[0] == 'None':
+ gt_flows_bi = self.fix_raft(gt_local_frames)
+ else:
+ gt_flows_bi = (flows_f.to(device), flows_b.to(device))
+
+ # ---- complete flow ----
+ pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow(gt_flows_bi, local_masks)
+ pred_flows_bi = self.fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, local_masks)
+ # pred_flows_bi = gt_flows_bi
+
+ # ---- image propagation ----
+ prop_imgs, updated_local_masks = self.netG.module.img_propagation(masked_local_frames, pred_flows_bi, local_masks, interpolation=self.interp_mode)
+ updated_masks = masks.clone()
+ updated_masks[:, :l_t, ...] = updated_local_masks.view(b, l_t, 1, h, w)
+ updated_frames = masked_frames.clone()
+ prop_local_frames = gt_local_frames * (1-local_masks) + prop_imgs.view(b, l_t, 3, h, w) * local_masks # merge
+ updated_frames[:, :l_t, ...] = prop_local_frames
+
+ # ---- feature propagation + Transformer ----
+ pred_imgs = self.netG(updated_frames, pred_flows_bi, masks, updated_masks, l_t)
+ pred_imgs = pred_imgs.view(b, -1, c, h, w)
+
+ # get the local frames
+ pred_local_frames = pred_imgs[:, :l_t, ...]
+ comp_local_frames = gt_local_frames * (1. - local_masks) + pred_local_frames * local_masks
+ comp_imgs = frames * (1. - masks) + pred_imgs * masks
+
+ gen_loss = 0
+ dis_loss = 0
+ # optimize net_g
+ if not self.config['model']['no_dis']:
+ for p in self.netD.parameters():
+ p.requires_grad = False
+
+ self.optimG.zero_grad()
+
+ # generator l1 loss
+ hole_loss = self.l1_loss(pred_imgs * masks, frames * masks)
+ hole_loss = hole_loss / torch.mean(masks) * self.config['losses']['hole_weight']
+ gen_loss += hole_loss
+ self.add_summary(self.gen_writer, 'loss/hole_loss', hole_loss.item())
+
+ valid_loss = self.l1_loss(pred_imgs * (1 - masks), frames * (1 - masks))
+ valid_loss = valid_loss / torch.mean(1-masks) * self.config['losses']['valid_weight']
+ gen_loss += valid_loss
+ self.add_summary(self.gen_writer, 'loss/valid_loss', valid_loss.item())
+
+ # perceptual loss
+ if self.config['losses']['perceptual_weight'] > 0:
+ perc_loss = self.perc_loss(pred_imgs.view(-1,3,h,w), frames.view(-1,3,h,w))[0] * self.config['losses']['perceptual_weight']
+ gen_loss += perc_loss
+ self.add_summary(self.gen_writer, 'loss/perc_loss', perc_loss.item())
+
+ # gan loss
+ if not self.config['model']['no_dis']:
+ # generator adversarial loss
+ gen_clip = self.netD(comp_imgs)
+ gan_loss = self.adversarial_loss(gen_clip, True, False)
+ gan_loss = gan_loss * self.config['losses']['adversarial_weight']
+ gen_loss += gan_loss
+ self.add_summary(self.gen_writer, 'loss/gan_loss', gan_loss.item())
+ gen_loss.backward()
+ self.optimG.step()
+
+ if not self.config['model']['no_dis']:
+ # optimize net_d
+ for p in self.netD.parameters():
+ p.requires_grad = True
+ self.optimD.zero_grad()
+
+ # discriminator adversarial loss
+ real_clip = self.netD(frames)
+ fake_clip = self.netD(comp_imgs.detach())
+ dis_real_loss = self.adversarial_loss(real_clip, True, True)
+ dis_fake_loss = self.adversarial_loss(fake_clip, False, True)
+ dis_loss += (dis_real_loss + dis_fake_loss) / 2
+ self.add_summary(self.dis_writer, 'loss/dis_vid_real', dis_real_loss.item())
+ self.add_summary(self.dis_writer, 'loss/dis_vid_fake', dis_fake_loss.item())
+ dis_loss.backward()
+ self.optimD.step()
+
+ self.update_learning_rate()
+
+ # write image to tensorboard
+ if self.iteration % 200 == 0:
+ # img to cpu
+ t = 0
+ gt_local_frames_cpu = ((gt_local_frames.view(b,-1,3,h,w) + 1)/2.0).cpu()
+ masked_local_frames = ((masked_local_frames.view(b,-1,3,h,w) + 1)/2.0).cpu()
+ prop_local_frames_cpu = ((prop_local_frames.view(b,-1,3,h,w) + 1)/2.0).cpu()
+ pred_local_frames_cpu = ((pred_local_frames.view(b,-1,3,h,w) + 1)/2.0).cpu()
+ img_results = torch.cat([masked_local_frames[0][t], gt_local_frames_cpu[0][t],
+ prop_local_frames_cpu[0][t], pred_local_frames_cpu[0][t]], 1)
+ img_results = torchvision.utils.make_grid(img_results, nrow=1, normalize=True)
+ if self.gen_writer is not None:
+ self.gen_writer.add_image(f'img/img:inp-gt-res-{t}', img_results, self.iteration)
+
+ t = 5
+ if masked_local_frames.shape[1] > 5:
+ img_results = torch.cat([masked_local_frames[0][t], gt_local_frames_cpu[0][t],
+ prop_local_frames_cpu[0][t], pred_local_frames_cpu[0][t]], 1)
+ img_results = torchvision.utils.make_grid(img_results, nrow=1, normalize=True)
+ if self.gen_writer is not None:
+ self.gen_writer.add_image(f'img/img:inp-gt-res-{t}', img_results, self.iteration)
+
+ # flow to cpu
+ gt_flows_forward_cpu = flow_to_image(gt_flows_bi[0][0]).cpu()
+ masked_flows_forward_cpu = (gt_flows_forward_cpu[0] * (1-local_masks[0][0].cpu())).to(gt_flows_forward_cpu)
+ pred_flows_forward_cpu = flow_to_image(pred_flows_bi[0][0]).cpu()
+
+ flow_results = torch.cat([gt_flows_forward_cpu[0], masked_flows_forward_cpu, pred_flows_forward_cpu[0]], 1)
+ if self.gen_writer is not None:
+ self.gen_writer.add_image('img/flow:gt-pred', flow_results, self.iteration)
+
+ # console logs
+ if self.config['global_rank'] == 0:
+ pbar.update(1)
+ if not self.config['model']['no_dis']:
+ pbar.set_description((f"d: {dis_loss.item():.3f}; "
+ f"hole: {hole_loss.item():.3f}; "
+ f"valid: {valid_loss.item():.3f}"))
+ else:
+ pbar.set_description((f"hole: {hole_loss.item():.3f}; "
+ f"valid: {valid_loss.item():.3f}"))
+
+ if self.iteration % self.train_args['log_freq'] == 0:
+ if not self.config['model']['no_dis']:
+ logging.info(f"[Iter {self.iteration}] "
+ f"d: {dis_loss.item():.4f}; "
+ f"hole: {hole_loss.item():.4f}; "
+ f"valid: {valid_loss.item():.4f}")
+ else:
+ logging.info(f"[Iter {self.iteration}] "
+ f"hole: {hole_loss.item():.4f}; "
+ f"valid: {valid_loss.item():.4f}")
+
+ # saving models
+ if self.iteration % self.train_args['save_freq'] == 0:
+ self.save(int(self.iteration))
+
+ if self.iteration > self.train_args['iterations']:
+ break
+
+ train_data = self.prefetcher.next()
\ No newline at end of file
diff --git a/propainter/core/trainer_flow_w_edge.py b/propainter/core/trainer_flow_w_edge.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4eba04c8a5fa56bce3e335e6036bc0e0a1e848a
--- /dev/null
+++ b/propainter/core/trainer_flow_w_edge.py
@@ -0,0 +1,380 @@
+import os
+import glob
+import logging
+import importlib
+from tqdm import tqdm
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from core.prefetch_dataloader import PrefetchDataLoader, CPUPrefetcher
+from torch.utils.data.distributed import DistributedSampler
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+from torch.utils.tensorboard import SummaryWriter
+
+from core.lr_scheduler import MultiStepRestartLR, CosineAnnealingRestartLR
+from core.dataset import TrainDataset
+
+from model.modules.flow_comp_raft import RAFT_bi, FlowLoss, EdgeLoss
+
+# from skimage.feature import canny
+from model.canny.canny_filter import Canny
+from RAFT.utils.flow_viz_pt import flow_to_image
+
+
+class Trainer:
+ def __init__(self, config):
+ self.config = config
+ self.epoch = 0
+ self.iteration = 0
+ self.num_local_frames = config['train_data_loader']['num_local_frames']
+ self.num_ref_frames = config['train_data_loader']['num_ref_frames']
+
+ # setup data set and data loader
+ self.train_dataset = TrainDataset(config['train_data_loader'])
+
+ self.train_sampler = None
+ self.train_args = config['trainer']
+ if config['distributed']:
+ self.train_sampler = DistributedSampler(
+ self.train_dataset,
+ num_replicas=config['world_size'],
+ rank=config['global_rank'])
+
+ dataloader_args = dict(
+ dataset=self.train_dataset,
+ batch_size=self.train_args['batch_size'] // config['world_size'],
+ shuffle=(self.train_sampler is None),
+ num_workers=self.train_args['num_workers'],
+ sampler=self.train_sampler,
+ drop_last=True)
+
+ self.train_loader = PrefetchDataLoader(self.train_args['num_prefetch_queue'], **dataloader_args)
+ self.prefetcher = CPUPrefetcher(self.train_loader)
+
+ # set raft
+ self.fix_raft = RAFT_bi(device = self.config['device'])
+ self.flow_loss = FlowLoss()
+ self.edge_loss = EdgeLoss()
+ self.canny = Canny(sigma=(2,2), low_threshold=0.1, high_threshold=0.2)
+
+ # setup models including generator and discriminator
+ net = importlib.import_module('model.' + config['model']['net'])
+ self.netG = net.RecurrentFlowCompleteNet()
+ # print(self.netG)
+ self.netG = self.netG.to(self.config['device'])
+
+ # setup optimizers and schedulers
+ self.setup_optimizers()
+ self.setup_schedulers()
+ self.load()
+
+ if config['distributed']:
+ self.netG = DDP(self.netG,
+ device_ids=[self.config['local_rank']],
+ output_device=self.config['local_rank'],
+ broadcast_buffers=True,
+ find_unused_parameters=True)
+
+ # set summary writer
+ self.dis_writer = None
+ self.gen_writer = None
+ self.summary = {}
+ if self.config['global_rank'] == 0 or (not config['distributed']):
+ self.gen_writer = SummaryWriter(
+ os.path.join(config['save_dir'], 'gen'))
+
+ def setup_optimizers(self):
+ """Set up optimizers."""
+ backbone_params = []
+ for name, param in self.netG.named_parameters():
+ if param.requires_grad:
+ backbone_params.append(param)
+ else:
+ print(f'Params {name} will not be optimized.')
+
+ optim_params = [
+ {
+ 'params': backbone_params,
+ 'lr': self.config['trainer']['lr']
+ },
+ ]
+
+ self.optimG = torch.optim.Adam(optim_params,
+ betas=(self.config['trainer']['beta1'],
+ self.config['trainer']['beta2']))
+
+
+ def setup_schedulers(self):
+ """Set up schedulers."""
+ scheduler_opt = self.config['trainer']['scheduler']
+ scheduler_type = scheduler_opt.pop('type')
+
+ if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
+ self.scheG = MultiStepRestartLR(
+ self.optimG,
+ milestones=scheduler_opt['milestones'],
+ gamma=scheduler_opt['gamma'])
+ elif scheduler_type == 'CosineAnnealingRestartLR':
+ self.scheG = CosineAnnealingRestartLR(
+ self.optimG,
+ periods=scheduler_opt['periods'],
+ restart_weights=scheduler_opt['restart_weights'])
+ else:
+ raise NotImplementedError(
+ f'Scheduler {scheduler_type} is not implemented yet.')
+
+ def update_learning_rate(self):
+ """Update learning rate."""
+ self.scheG.step()
+
+ def get_lr(self):
+ """Get current learning rate."""
+ return self.optimG.param_groups[0]['lr']
+
+ def add_summary(self, writer, name, val):
+ """Add tensorboard summary."""
+ if name not in self.summary:
+ self.summary[name] = 0
+ self.summary[name] += val
+ n = self.train_args['log_freq']
+ if writer is not None and self.iteration % n == 0:
+ writer.add_scalar(name, self.summary[name] / n, self.iteration)
+ self.summary[name] = 0
+
+ def load(self):
+ """Load netG."""
+ # get the latest checkpoint
+ model_path = self.config['save_dir']
+ if os.path.isfile(os.path.join(model_path, 'latest.ckpt')):
+ latest_epoch = open(os.path.join(model_path, 'latest.ckpt'),
+ 'r').read().splitlines()[-1]
+ else:
+ ckpts = [
+ os.path.basename(i).split('.pth')[0]
+ for i in glob.glob(os.path.join(model_path, '*.pth'))
+ ]
+ ckpts.sort()
+ latest_epoch = ckpts[-1][4:] if len(ckpts) > 0 else None
+
+ if latest_epoch is not None:
+ gen_path = os.path.join(model_path, f'gen_{int(latest_epoch):06d}.pth')
+ opt_path = os.path.join(model_path,f'opt_{int(latest_epoch):06d}.pth')
+
+ if self.config['global_rank'] == 0:
+ print(f'Loading model from {gen_path}...')
+ dataG = torch.load(gen_path, map_location=self.config['device'])
+ self.netG.load_state_dict(dataG)
+
+
+ data_opt = torch.load(opt_path, map_location=self.config['device'])
+ self.optimG.load_state_dict(data_opt['optimG'])
+ self.scheG.load_state_dict(data_opt['scheG'])
+
+ self.epoch = data_opt['epoch']
+ self.iteration = data_opt['iteration']
+
+ else:
+ if self.config['global_rank'] == 0:
+ print('Warning: no trained model found. '
+ 'An initialized model will be used.')
+
+ def save(self, it):
+ """Save parameters every eval_epoch"""
+ if self.config['global_rank'] == 0:
+ # configure path
+ gen_path = os.path.join(self.config['save_dir'],
+ f'gen_{it:06d}.pth')
+ opt_path = os.path.join(self.config['save_dir'],
+ f'opt_{it:06d}.pth')
+ print(f'\nsaving model to {gen_path} ...')
+
+ # remove .module for saving
+ if isinstance(self.netG, torch.nn.DataParallel) or isinstance(self.netG, DDP):
+ netG = self.netG.module
+ else:
+ netG = self.netG
+
+ # save checkpoints
+ torch.save(netG.state_dict(), gen_path)
+ torch.save(
+ {
+ 'epoch': self.epoch,
+ 'iteration': self.iteration,
+ 'optimG': self.optimG.state_dict(),
+ 'scheG': self.scheG.state_dict()
+ }, opt_path)
+
+ latest_path = os.path.join(self.config['save_dir'], 'latest.ckpt')
+ os.system(f"echo {it:06d} > {latest_path}")
+
+ def train(self):
+ """training entry"""
+ pbar = range(int(self.train_args['iterations']))
+ if self.config['global_rank'] == 0:
+ pbar = tqdm(pbar,
+ initial=self.iteration,
+ dynamic_ncols=True,
+ smoothing=0.01)
+
+ os.makedirs('logs', exist_ok=True)
+
+ logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s %(filename)s[line:%(lineno)d]"
+ "%(levelname)s %(message)s",
+ datefmt="%a, %d %b %Y %H:%M:%S",
+ filename=f"logs/{self.config['save_dir'].split('/')[-1]}.log",
+ filemode='w')
+
+ while True:
+ self.epoch += 1
+ self.prefetcher.reset()
+ if self.config['distributed']:
+ self.train_sampler.set_epoch(self.epoch)
+ self._train_epoch(pbar)
+ if self.iteration > self.train_args['iterations']:
+ break
+ print('\nEnd training....')
+
+ # def get_edges(self, flows): # fgvc
+ # # (b, t, 2, H, W)
+ # b, t, _, h, w = flows.shape
+ # flows = flows.view(-1, 2, h, w)
+ # flows_list = flows.permute(0, 2, 3, 1).cpu().numpy()
+ # edges = []
+ # for f in list(flows_list):
+ # flows_gray = (f[:, :, 0] ** 2 + f[:, :, 1] ** 2) ** 0.5
+ # if flows_gray.max() < 1:
+ # flows_gray = flows_gray*0
+ # else:
+ # flows_gray = flows_gray / flows_gray.max()
+
+ # edge = canny(flows_gray, sigma=2, low_threshold=0.1, high_threshold=0.2) # fgvc
+ # edge = torch.from_numpy(edge).view(1, 1, h, w).float()
+ # edges.append(edge)
+ # edges = torch.stack(edges, dim=0).to(self.config['device'])
+ # edges = edges.view(b, t, 1, h, w)
+ # return edges
+
+ def get_edges(self, flows):
+ # (b, t, 2, H, W)
+ b, t, _, h, w = flows.shape
+ flows = flows.view(-1, 2, h, w)
+ flows_gray = (flows[:, 0, None] ** 2 + flows[:, 1, None] ** 2) ** 0.5
+ if flows_gray.max() < 1:
+ flows_gray = flows_gray*0
+ else:
+ flows_gray = flows_gray / flows_gray.max()
+
+ magnitude, edges = self.canny(flows_gray.float())
+ edges = edges.view(b, t, 1, h, w)
+ return edges
+
+ def _train_epoch(self, pbar):
+ """Process input and calculate loss every training epoch"""
+ device = self.config['device']
+ train_data = self.prefetcher.next()
+ while train_data is not None:
+ self.iteration += 1
+ frames, masks, flows_f, flows_b, _ = train_data
+ frames, masks = frames.to(device), masks.to(device)
+ masks = masks.float()
+
+ l_t = self.num_local_frames
+ b, t, c, h, w = frames.size()
+ gt_local_frames = frames[:, :l_t, ...]
+ local_masks = masks[:, :l_t, ...].contiguous()
+
+ # get gt optical flow
+ if flows_f[0] == 'None' or flows_b[0] == 'None':
+ gt_flows_bi = self.fix_raft(gt_local_frames)
+ else:
+ gt_flows_bi = (flows_f.to(device), flows_b.to(device))
+
+ # get gt edge
+ gt_edges_forward = self.get_edges(gt_flows_bi[0])
+ gt_edges_backward = self.get_edges(gt_flows_bi[1])
+ gt_edges_bi = [gt_edges_forward, gt_edges_backward]
+
+ # complete flow
+ pred_flows_bi, pred_edges_bi = self.netG.module.forward_bidirect_flow(gt_flows_bi, local_masks)
+
+ # optimize net_g
+ self.optimG.zero_grad()
+
+ # compute flow loss
+ flow_loss, warp_loss = self.flow_loss(pred_flows_bi, gt_flows_bi, local_masks, gt_local_frames)
+ flow_loss = flow_loss * self.config['losses']['flow_weight']
+ warp_loss = warp_loss * 0.01
+ self.add_summary(self.gen_writer, 'loss/flow_loss', flow_loss.item())
+ self.add_summary(self.gen_writer, 'loss/warp_loss', warp_loss.item())
+
+ # compute edge loss
+ edge_loss = self.edge_loss(pred_edges_bi, gt_edges_bi, local_masks)
+ edge_loss = edge_loss*1.0
+ self.add_summary(self.gen_writer, 'loss/edge_loss', edge_loss.item())
+
+ loss = flow_loss + warp_loss + edge_loss
+ loss.backward()
+ self.optimG.step()
+ self.update_learning_rate()
+
+ # write image to tensorboard
+ # if self.iteration % 200 == 0:
+ if self.iteration % 200 == 0 and self.gen_writer is not None:
+ t = 5
+ # forward to cpu
+ gt_flows_forward_cpu = flow_to_image(gt_flows_bi[0][0]).cpu()
+ masked_flows_forward_cpu = (gt_flows_forward_cpu[t] * (1-local_masks[0][t].cpu())).to(gt_flows_forward_cpu)
+ pred_flows_forward_cpu = flow_to_image(pred_flows_bi[0][0]).cpu()
+
+ flow_results = torch.cat([gt_flows_forward_cpu[t], masked_flows_forward_cpu, pred_flows_forward_cpu[t]], 1)
+ self.gen_writer.add_image('img/flow-f:gt-pred', flow_results, self.iteration)
+
+ # backward to cpu
+ gt_flows_backward_cpu = flow_to_image(gt_flows_bi[1][0]).cpu()
+ masked_flows_backward_cpu = (gt_flows_backward_cpu[t] * (1-local_masks[0][t+1].cpu())).to(gt_flows_backward_cpu)
+ pred_flows_backward_cpu = flow_to_image(pred_flows_bi[1][0]).cpu()
+
+ flow_results = torch.cat([gt_flows_backward_cpu[t], masked_flows_backward_cpu, pred_flows_backward_cpu[t]], 1)
+ self.gen_writer.add_image('img/flow-b:gt-pred', flow_results, self.iteration)
+
+ # TODO: show edge
+ # forward
+ gt_edges_forward_cpu = gt_edges_bi[0][0].cpu()
+ masked_edges_forward_cpu = (gt_edges_forward_cpu[t] * (1-local_masks[0][t].cpu())).to(gt_edges_forward_cpu)
+ pred_edges_forward_cpu = pred_edges_bi[0][0].cpu()
+
+ edge_results = torch.cat([gt_edges_forward_cpu[t], masked_edges_forward_cpu, pred_edges_forward_cpu[t]], 1)
+ self.gen_writer.add_image('img/edge-f:gt-pred', edge_results, self.iteration)
+ # backward
+ gt_edges_backward_cpu = gt_edges_bi[1][0].cpu()
+ masked_edges_backward_cpu = (gt_edges_backward_cpu[t] * (1-local_masks[0][t+1].cpu())).to(gt_edges_backward_cpu)
+ pred_edges_backward_cpu = pred_edges_bi[1][0].cpu()
+
+ edge_results = torch.cat([gt_edges_backward_cpu[t], masked_edges_backward_cpu, pred_edges_backward_cpu[t]], 1)
+ self.gen_writer.add_image('img/edge-b:gt-pred', edge_results, self.iteration)
+
+ # console logs
+ if self.config['global_rank'] == 0:
+ pbar.update(1)
+ pbar.set_description((f"flow: {flow_loss.item():.3f}; "
+ f"warp: {warp_loss.item():.3f}; "
+ f"edge: {edge_loss.item():.3f}; "
+ f"lr: {self.get_lr()}"))
+
+ if self.iteration % self.train_args['log_freq'] == 0:
+ logging.info(f"[Iter {self.iteration}] "
+ f"flow: {flow_loss.item():.4f}; "
+ f"warp: {warp_loss.item():.4f}")
+
+ # saving models
+ if self.iteration % self.train_args['save_freq'] == 0:
+ self.save(int(self.iteration))
+
+ if self.iteration > self.train_args['iterations']:
+ break
+
+ train_data = self.prefetcher.next()
\ No newline at end of file
diff --git a/propainter/core/utils.py b/propainter/core/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..37dccb2d26e6916aacbd530ab03726a7c54f8ec8
--- /dev/null
+++ b/propainter/core/utils.py
@@ -0,0 +1,371 @@
+import os
+import io
+import cv2
+import random
+import numpy as np
+from PIL import Image, ImageOps
+import zipfile
+import math
+
+import torch
+import matplotlib
+import matplotlib.patches as patches
+from matplotlib.path import Path
+from matplotlib import pyplot as plt
+from torchvision import transforms
+
+# matplotlib.use('agg')
+
+# ###########################################################################
+# Directory IO
+# ###########################################################################
+
+
+def read_dirnames_under_root(root_dir):
+ dirnames = [
+ name for i, name in enumerate(sorted(os.listdir(root_dir)))
+ if os.path.isdir(os.path.join(root_dir, name))
+ ]
+ print(f'Reading directories under {root_dir}, num: {len(dirnames)}')
+ return dirnames
+
+
+class TrainZipReader(object):
+ file_dict = dict()
+
+ def __init__(self):
+ super(TrainZipReader, self).__init__()
+
+ @staticmethod
+ def build_file_dict(path):
+ file_dict = TrainZipReader.file_dict
+ if path in file_dict:
+ return file_dict[path]
+ else:
+ file_handle = zipfile.ZipFile(path, 'r')
+ file_dict[path] = file_handle
+ return file_dict[path]
+
+ @staticmethod
+ def imread(path, idx):
+ zfile = TrainZipReader.build_file_dict(path)
+ filelist = zfile.namelist()
+ filelist.sort()
+ data = zfile.read(filelist[idx])
+ #
+ im = Image.open(io.BytesIO(data))
+ return im
+
+
+class TestZipReader(object):
+ file_dict = dict()
+
+ def __init__(self):
+ super(TestZipReader, self).__init__()
+
+ @staticmethod
+ def build_file_dict(path):
+ file_dict = TestZipReader.file_dict
+ if path in file_dict:
+ return file_dict[path]
+ else:
+ file_handle = zipfile.ZipFile(path, 'r')
+ file_dict[path] = file_handle
+ return file_dict[path]
+
+ @staticmethod
+ def imread(path, idx):
+ zfile = TestZipReader.build_file_dict(path)
+ filelist = zfile.namelist()
+ filelist.sort()
+ data = zfile.read(filelist[idx])
+ file_bytes = np.asarray(bytearray(data), dtype=np.uint8)
+ im = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
+ im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB))
+ # im = Image.open(io.BytesIO(data))
+ return im
+
+
+# ###########################################################################
+# Data augmentation
+# ###########################################################################
+
+
+def to_tensors():
+ """Compose Stack + ToTorchFormatTensor: a list of T PIL images (mode 'RGB'
+ or 'L') becomes a float tensor of shape (T, C, H, W) scaled to [0, 1]."""
+ return transforms.Compose([Stack(), ToTorchFormatTensor()])
+
+
+class GroupRandomHorizontalFlowFlip(object):
+ """Randomly horizontally flips the given PIL.Image with a probability of 0.5
+ """
+ def __call__(self, img_group, flowF_group, flowB_group):
+ v = random.random()
+ if v < 0.5:
+ ret_img = [
+ img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group
+ ]
+ ret_flowF = [ff[:, ::-1] * [-1.0, 1.0] for ff in flowF_group]
+ ret_flowB = [fb[:, ::-1] * [-1.0, 1.0] for fb in flowB_group]
+ return ret_img, ret_flowF, ret_flowB
+ else:
+ return img_group, flowF_group, flowB_group
+
+
+class GroupRandomHorizontalFlip(object):
+ """Randomly horizontally flips the given PIL.Image with a probability of 0.5
+ """
+ def __call__(self, img_group, is_flow=False):
+ v = random.random()
+ if v < 0.5:
+ ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group]
+ if is_flow:
+ for i in range(0, len(ret), 2):
+ # invert flow pixel values when flipping
+ ret[i] = ImageOps.invert(ret[i])
+ return ret
+ else:
+ return img_group
+
+
+class Stack(object):
+ def __init__(self, roll=False):
+ self.roll = roll
+
+ def __call__(self, img_group):
+ mode = img_group[0].mode
+ if mode == '1':
+ img_group = [img.convert('L') for img in img_group]
+ mode = 'L'
+ if mode == 'L':
+ return np.stack([np.expand_dims(x, 2) for x in img_group], axis=2)
+ elif mode == 'RGB':
+ if self.roll:
+ return np.stack([np.array(x)[:, :, ::-1] for x in img_group],
+ axis=2)
+ else:
+ return np.stack(img_group, axis=2)
+ else:
+ raise NotImplementedError(f"Image mode {mode}")
+
+
+class ToTorchFormatTensor(object):
+ """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
+ to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """
+ def __init__(self, div=True):
+ self.div = div
+
+ def __call__(self, pic):
+ if isinstance(pic, np.ndarray):
+ # numpy array from Stack has shape (H, W, L, C); permute to (L, C, H, W)
+ img = torch.from_numpy(pic).permute(2, 3, 0, 1).contiguous()
+ else:
+ # handle PIL Image
+ img = torch.ByteTensor(torch.ByteStorage.from_buffer(
+ pic.tobytes()))
+ img = img.view(pic.size[1], pic.size[0], len(pic.mode))
+ # put it from HWC to CHW format
+ # yikes, this transpose takes 80% of the loading time/CPU
+ img = img.transpose(0, 1).transpose(0, 2).contiguous()
+ img = img.float().div(255) if self.div else img.float()
+ return img
+
+
+# ###########################################################################
+# Create masks with random shape
+# ###########################################################################
+
+
+def create_random_shape_with_random_motion(video_length,
+ imageHeight=240,
+ imageWidth=432):
+ # get a random shape
+ height = random.randint(imageHeight // 3, imageHeight - 1)
+ width = random.randint(imageWidth // 3, imageWidth - 1)
+ edge_num = random.randint(6, 8)
+ ratio = random.randint(6, 8) / 10
+
+ region = get_random_shape(edge_num=edge_num,
+ ratio=ratio,
+ height=height,
+ width=width)
+ region_width, region_height = region.size
+ # get random position
+ x, y = random.randint(0, imageHeight - region_height), random.randint(
+ 0, imageWidth - region_width)
+ velocity = get_random_velocity(max_speed=3)
+ m = Image.fromarray(np.zeros((imageHeight, imageWidth)).astype(np.uint8))
+ m.paste(region, (y, x, y + region.size[0], x + region.size[1]))
+ masks = [m.convert('L')]
+ # return fixed masks
+ if random.uniform(0, 1) > 0.5:
+ return masks * video_length
+ # return moving masks
+ for _ in range(video_length - 1):
+ x, y, velocity = random_move_control_points(x,
+ y,
+ imageHeight,
+ imageWidth,
+ velocity,
+ region.size,
+ maxLineAcceleration=(3,
+ 0.5),
+ maxInitSpeed=3)
+ m = Image.fromarray(
+ np.zeros((imageHeight, imageWidth)).astype(np.uint8))
+ m.paste(region, (y, x, y + region.size[0], x + region.size[1]))
+ masks.append(m.convert('L'))
+ return masks
+
+
+def create_random_shape_with_random_motion_zoom_rotation(video_length, zoomin=0.9, zoomout=1.1, rotmin=1, rotmax=10, imageHeight=240, imageWidth=432):
+ # get a random shape
+ assert zoomin < 1, "Zoom-in parameter must be smaller than 1"
+ assert zoomout > 1, "Zoom-out parameter must be larger than 1"
+ assert rotmin < rotmax, "Minimum rotation must be smaller than maximum rotation!"
+ height = random.randint(imageHeight//3, imageHeight-1)
+ width = random.randint(imageWidth//3, imageWidth-1)
+ edge_num = random.randint(6, 8)
+ ratio = random.randint(6, 8)/10
+ region = get_random_shape(
+ edge_num=edge_num, ratio=ratio, height=height, width=width)
+ region_width, region_height = region.size
+ # get random position
+ x, y = random.randint(
+ 0, imageHeight-region_height), random.randint(0, imageWidth-region_width)
+ velocity = get_random_velocity(max_speed=3)
+ m = Image.fromarray(np.zeros((imageHeight, imageWidth)).astype(np.uint8))
+ m.paste(region, (y, x, y+region.size[0], x+region.size[1]))
+ masks = [m.convert('L')]
+ # return fixed masks
+ if random.uniform(0, 1) > 0.5:
+ return masks*video_length # -> directly copy all the base masks
+ # return moving masks
+ for _ in range(video_length-1):
+ x, y, velocity = random_move_control_points(
+ x, y, imageHeight, imageWidth, velocity, region.size, maxLineAcceleration=(3, 0.5), maxInitSpeed=3)
+ m = Image.fromarray(
+ np.zeros((imageHeight, imageWidth)).astype(np.uint8))
+ ### added by kaidong to simulate zoom-in, zoom-out and rotation
+ extra_transform = random.uniform(0, 1)
+ # zoom in and zoom out
+ if extra_transform > 0.75:
+ resize_coefficient = random.uniform(zoomin, zoomout)
+ region = region.resize((math.ceil(region_width * resize_coefficient), math.ceil(region_height * resize_coefficient)), Image.NEAREST)
+ m.paste(region, (y, x, y + region.size[0], x + region.size[1]))
+ region_width, region_height = region.size
+ # rotation
+ elif extra_transform > 0.5:
+ m.paste(region, (y, x, y + region.size[0], x + region.size[1]))
+ m = m.rotate(random.randint(rotmin, rotmax))
+ # region_width, region_height = region.size
+ ### end
+ else:
+ m.paste(region, (y, x, y+region.size[0], x+region.size[1]))
+ masks.append(m.convert('L'))
+ return masks
+
+
+def get_random_shape(edge_num=9, ratio=0.7, width=432, height=240):
+ '''
+ The path has an initial point plus 3 points per cubic Bezier curve, so the
+ curve only passes through `edge_num` points, which become the sharp edges;
+ the other two control points of each segment only shape the curve.
+ edge_num: number of possibly sharp edges.
+ points_num (= edge_num*3 + 1): number of points in the Path.
+ ratio: in (0, 1), magnitude of the perturbation away from the unit circle.
+ '''
+ points_num = edge_num*3 + 1
+ angles = np.linspace(0, 2*np.pi, points_num)
+ codes = np.full(points_num, Path.CURVE4)
+ codes[0] = Path.MOVETO
+ # Using this instead of Path.CLOSEPOLY avoids an unnecessary straight line
+ verts = np.stack((np.cos(angles), np.sin(angles))).T * \
+ (2*ratio*np.random.random(points_num)+1-ratio)[:, None]
+ verts[-1, :] = verts[0, :]
+ path = Path(verts, codes)
+ # draw paths into images
+ fig = plt.figure()
+ ax = fig.add_subplot(111)
+ patch = patches.PathPatch(path, facecolor='black', lw=2)
+ ax.add_patch(patch)
+ ax.set_xlim(np.min(verts)*1.1, np.max(verts)*1.1)
+ ax.set_ylim(np.min(verts)*1.1, np.max(verts)*1.1)
+ ax.axis('off') # removes the axis to leave only the shape
+ fig.canvas.draw()
+ # convert plt images into numpy images
+ data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
+ data = data.reshape((fig.canvas.get_width_height()[::-1] + (3,)))
+ plt.close(fig)
+ # postprocess
+ data = cv2.resize(data, (width, height))[:, :, 0]
+ data = (1 - np.array(data > 0).astype(np.uint8))*255
+ coordinates = np.where(data > 0)
+ xmin, xmax, ymin, ymax = np.min(coordinates[0]), np.max(
+ coordinates[0]), np.min(coordinates[1]), np.max(coordinates[1])
+ region = Image.fromarray(data).crop((ymin, xmin, ymax, xmax))
+ return region
+
+
+def random_accelerate(velocity, maxAcceleration, dist='uniform'):
+ speed, angle = velocity
+ d_speed, d_angle = maxAcceleration
+ if dist == 'uniform':
+ speed += np.random.uniform(-d_speed, d_speed)
+ angle += np.random.uniform(-d_angle, d_angle)
+ elif dist == 'guassian':
+ speed += np.random.normal(0, d_speed / 2)
+ angle += np.random.normal(0, d_angle / 2)
+ else:
+ raise NotImplementedError(
+ f'Distribution type {dist} is not supported.')
+ return (speed, angle)
+
+
+def get_random_velocity(max_speed=3, dist='uniform'):
+ if dist == 'uniform':
+ speed = np.random.uniform(max_speed)
+ elif dist == 'guassian':
+ speed = np.abs(np.random.normal(0, max_speed / 2))
+ else:
+ raise NotImplementedError(
+ f'Distribution type {dist} is not supported.')
+ angle = np.random.uniform(0, 2 * np.pi)
+ return (speed, angle)
+
+
+def random_move_control_points(X,
+ Y,
+ imageHeight,
+ imageWidth,
+ lineVelocity,
+ region_size,
+ maxLineAcceleration=(3, 0.5),
+ maxInitSpeed=3):
+ region_width, region_height = region_size
+ speed, angle = lineVelocity
+ X += int(speed * np.cos(angle))
+ Y += int(speed * np.sin(angle))
+ lineVelocity = random_accelerate(lineVelocity,
+ maxLineAcceleration,
+ dist='guassian')
+ if ((X > imageHeight - region_height) or (X < 0)
+ or (Y > imageWidth - region_width) or (Y < 0)):
+ lineVelocity = get_random_velocity(maxInitSpeed, dist='guassian')
+ new_X = np.clip(X, 0, imageHeight - region_height)
+ new_Y = np.clip(Y, 0, imageWidth - region_width)
+ return new_X, new_Y, lineVelocity
+
+
+if __name__ == '__main__':
+
+ trials = 10
+ for _ in range(trials):
+ video_length = 10
+ # The returned masks are either stationary (50%) or moving (50%)
+ masks = create_random_shape_with_random_motion(video_length,
+ imageHeight=240,
+ imageWidth=432)
+
+ for m in masks:
+ cv2.imshow('mask', np.array(m))
+ cv2.waitKey(500)
diff --git a/propainter/inference.py b/propainter/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..b180d49edbcc03df8bb03a3fb8ac89d16b22855d
--- /dev/null
+++ b/propainter/inference.py
@@ -0,0 +1,520 @@
+# -*- coding: utf-8 -*-
+import os
+import cv2
+import numpy as np
+import scipy.ndimage
+from PIL import Image
+from tqdm import tqdm
+import torch
+import torchvision
+import gc
+
+try:
+ from model.modules.flow_comp_raft import RAFT_bi
+ from model.recurrent_flow_completion import RecurrentFlowCompleteNet
+ from model.propainter import InpaintGenerator
+ from utils.download_util import load_file_from_url
+ from core.utils import to_tensors
+ from model.misc import get_device
+except:
+ from propainter.model.modules.flow_comp_raft import RAFT_bi
+ from propainter.model.recurrent_flow_completion import RecurrentFlowCompleteNet
+ from propainter.model.propainter import InpaintGenerator
+ from propainter.utils.download_util import load_file_from_url
+ from propainter.core.utils import to_tensors
+ from propainter.model.misc import get_device
+
+import warnings
+warnings.filterwarnings("ignore")
+
+pretrain_model_url = 'https://github.com/sczhou/ProPainter/releases/download/v0.1.0/'
+MaxSideThresh = 960
+
+
+# resize frames
+def resize_frames(frames, size=None):
+ if size is not None:
+ out_size = size
+ process_size = (out_size[0]-out_size[0]%8, out_size[1]-out_size[1]%8)
+ frames = [f.resize(process_size) for f in frames]
+ else:
+ out_size = frames[0].size
+ process_size = (out_size[0]-out_size[0]%8, out_size[1]-out_size[1]%8)
+ if not out_size == process_size:
+ frames = [f.resize(process_size) for f in frames]
+
+ return frames, process_size, out_size
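+
+# Note: both sides are rounded down to a multiple of 8 above, presumably because
+# the inpainting network downsamples by powers of two and expects frame sizes
+# divisible by 8.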
+
+# read frames from video
+def read_frame_from_videos(frame_root, video_length):
+ if frame_root.endswith(('mp4', 'mov', 'avi', 'MP4', 'MOV', 'AVI')): # input video path
+ video_name = os.path.basename(frame_root)[:-4]
+ vframes, aframes, info = torchvision.io.read_video(filename=frame_root, pts_unit='sec', end_pts=video_length) # RGB
+ frames = list(vframes.numpy())
+ frames = [Image.fromarray(f) for f in frames]
+ fps = info['video_fps']
+ nframes = len(frames)
+ else:
+ video_name = os.path.basename(frame_root)
+ frames = []
+ fr_lst = sorted(os.listdir(frame_root))
+ for fr in fr_lst:
+ frame = cv2.imread(os.path.join(frame_root, fr))
+ frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+ frames.append(frame)
+ fps = None
+ nframes = len(frames)
+ size = frames[0].size
+
+ return frames, fps, size, video_name, nframes
+
+def binary_mask(mask, th=0.1):
+ mask[mask>th] = 1
+ mask[mask<=th] = 0
+ return mask
+
+# read frame-wise masks
+def read_mask(mpath, frames_len, size, flow_mask_dilates=8, mask_dilates=5):
+ masks_img = []
+ masks_dilated = []
+ flow_masks = []
+
+ if mpath.endswith(('jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG')): # input single img path
+ masks_img = [Image.open(mpath)]
+ elif mpath.endswith(('mp4', 'mov', 'avi', 'MP4', 'MOV', 'AVI')): # input video path
+ cap = cv2.VideoCapture(mpath)
+ if not cap.isOpened():
+ print("Error: Could not open video.")
+ exit()
+ idx = 0
+ while True:
+ ret, frame = cap.read()
+ if not ret:
+ break
+ if(idx >= frames_len):
+ break
+ masks_img.append(Image.fromarray(frame))
+ idx += 1
+ cap.release()
+ else:
+ mnames = sorted(os.listdir(mpath))
+ for mp in mnames:
+ masks_img.append(Image.open(os.path.join(mpath, mp)))
+ # print(mp)
+
+ for mask_img in masks_img:
+ if size is not None:
+ mask_img = mask_img.resize(size, Image.NEAREST)
+ mask_img = np.array(mask_img.convert('L'))
+
+ # Dilate the mask (by `flow_mask_dilates` pixels) so that all known pixels are trustworthy
+ if flow_mask_dilates > 0:
+ flow_mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=flow_mask_dilates).astype(np.uint8)
+ else:
+ flow_mask_img = binary_mask(mask_img).astype(np.uint8)
+ # Close the small holes inside the foreground objects
+ # flow_mask_img = cv2.morphologyEx(flow_mask_img, cv2.MORPH_CLOSE, np.ones((21, 21),np.uint8)).astype(bool)
+ # flow_mask_img = scipy.ndimage.binary_fill_holes(flow_mask_img).astype(np.uint8)
+ flow_masks.append(Image.fromarray(flow_mask_img * 255))
+
+ if mask_dilates > 0:
+ mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=mask_dilates).astype(np.uint8)
+ else:
+ mask_img = binary_mask(mask_img).astype(np.uint8)
+ masks_dilated.append(Image.fromarray(mask_img * 255))
+
+ if len(masks_img) == 1:
+ flow_masks = flow_masks * frames_len
+ masks_dilated = masks_dilated * frames_len
+
+ return flow_masks, masks_dilated
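+
+# Note: a single mask image is reused for every frame, and two dilated variants
+# are returned -- one for flow completion and one as the inpainting mask.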
+
+def get_ref_index(mid_neighbor_id, neighbor_ids, length, ref_stride=10, ref_num=-1):
+ ref_index = []
+ if ref_num == -1:
+ for i in range(0, length, ref_stride):
+ if i not in neighbor_ids:
+ ref_index.append(i)
+ else:
+ start_idx = max(0, mid_neighbor_id - ref_stride * (ref_num // 2))
+ end_idx = min(length, mid_neighbor_id + ref_stride * (ref_num // 2))
+ for i in range(start_idx, end_idx, ref_stride):
+ if i not in neighbor_ids:
+ if len(ref_index) > ref_num:
+ break
+ ref_index.append(i)
+ return ref_index
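+
+# Example (sketch): with length=80, ref_stride=10, ref_num=-1 and
+# neighbor_ids=list(range(10, 21)), the returned non-local reference indices
+# are [0, 30, 40, 50, 60, 70] -- every ref_stride-th frame outside the window.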
+
+
+class Propainter:
+ def __init__(
+ self, propainter_model_dir, device):
+ self.device = device
+ ##############################################
+ # set up RAFT and flow completion model
+ ##############################################
+ ckpt_path = load_file_from_url(url=os.path.join(pretrain_model_url, 'raft-things.pth'),
+ model_dir=propainter_model_dir, progress=True, file_name=None)
+ self.fix_raft = RAFT_bi(ckpt_path, device)
+
+ ckpt_path = load_file_from_url(url=os.path.join(pretrain_model_url, 'recurrent_flow_completion.pth'),
+ model_dir=propainter_model_dir, progress=True, file_name=None)
+ self.fix_flow_complete = RecurrentFlowCompleteNet(ckpt_path)
+ for p in self.fix_flow_complete.parameters():
+ p.requires_grad = False
+ self.fix_flow_complete.to(device)
+ self.fix_flow_complete.eval()
+
+ ##############################################
+ # set up ProPainter model
+ ##############################################
+ ckpt_path = load_file_from_url(url=os.path.join(pretrain_model_url, 'ProPainter.pth'),
+ model_dir=propainter_model_dir, progress=True, file_name=None)
+ self.model = InpaintGenerator(model_path=ckpt_path).to(device)
+ self.model.eval()
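+
+ # Usage sketch (paths are placeholders, not files shipped with this repo):
+ #   device = get_device()
+ #   pp = Propainter('weights/propainter', device=device)
+ #   pp.forward('path/to/video.mp4', 'path/to/mask.mp4', 'results/out.mp4',
+ #              resize_ratio=1.0)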
+ def forward(self, video, mask, output_path, resize_ratio=0.6, video_length=2, height=-1, width=-1,
+ mask_dilation=4, ref_stride=10, neighbor_length=10, subvideo_length=80,
+ raft_iter=20, save_fps=24, save_frames=False, fp16=True):
+
+ # Use fp16 precision during inference to reduce running memory cost
+ use_half = True if fp16 else False
+ if self.device == torch.device('cpu'):
+ use_half = False
+
+ ################ read input video ################
+ frames, fps, size, video_name, nframes = read_frame_from_videos(video, video_length)
+ frames = frames[:nframes]
+ if not width == -1 and not height == -1:
+ size = (width, height)
+
+ longer_edge = max(size[0], size[1])
+ if(longer_edge > MaxSideThresh):
+ scale = MaxSideThresh / longer_edge
+ resize_ratio = resize_ratio * scale
+ if not resize_ratio == 1.0:
+ size = (int(resize_ratio * size[0]), int(resize_ratio * size[1]))
+
+ frames, size, out_size = resize_frames(frames, size)
+ fps = save_fps if fps is None else fps
+
+ ################ read mask ################
+ frames_len = len(frames)
+ flow_masks, masks_dilated = read_mask(mask, frames_len, size,
+ flow_mask_dilates=mask_dilation,
+ mask_dilates=mask_dilation)
+ flow_masks = flow_masks[:nframes]
+ masks_dilated = masks_dilated[:nframes]
+ w, h = size
+
+ ################ adjust input ################
+ frames_len = min(len(frames), len(masks_dilated))
+ frames = frames[:frames_len]
+ flow_masks = flow_masks[:frames_len]
+ masks_dilated = masks_dilated[:frames_len]
+
+ ori_frames_inp = [np.array(f).astype(np.uint8) for f in frames]
+ frames = to_tensors()(frames).unsqueeze(0) * 2 - 1
+ flow_masks = to_tensors()(flow_masks).unsqueeze(0)
+ masks_dilated = to_tensors()(masks_dilated).unsqueeze(0)
+ frames, flow_masks, masks_dilated = frames.to(self.device), flow_masks.to(self.device), masks_dilated.to(self.device)
+
+ ##############################################
+ # ProPainter inference
+ ##############################################
+ video_length = frames.size(1)
+ print(f'Priori generating: [{video_length} frames]...')
+ with torch.no_grad():
+ # ---- compute flow ----
+ new_longer_edge = max(frames.size(-1), frames.size(-2))
+ if new_longer_edge <= 640:
+ short_clip_len = 12
+ elif new_longer_edge <= 720:
+ short_clip_len = 8
+ elif new_longer_edge <= 1280:
+ short_clip_len = 4
+ else:
+ short_clip_len = 2
+
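+ # Heuristic: the higher the resolution, the fewer frames per RAFT chunk,
+ # presumably to keep peak GPU memory bounded.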
+ # use fp32 for RAFT
+ if frames.size(1) > short_clip_len:
+ gt_flows_f_list, gt_flows_b_list = [], []
+ for f in range(0, video_length, short_clip_len):
+ end_f = min(video_length, f + short_clip_len)
+ if f == 0:
+ flows_f, flows_b = self.fix_raft(frames[:,f:end_f], iters=raft_iter)
+ else:
+ flows_f, flows_b = self.fix_raft(frames[:,f-1:end_f], iters=raft_iter)
+
+ gt_flows_f_list.append(flows_f)
+ gt_flows_b_list.append(flows_b)
+ torch.cuda.empty_cache()
+
+ gt_flows_f = torch.cat(gt_flows_f_list, dim=1)
+ gt_flows_b = torch.cat(gt_flows_b_list, dim=1)
+ gt_flows_bi = (gt_flows_f, gt_flows_b)
+ else:
+ gt_flows_bi = self.fix_raft(frames, iters=raft_iter)
+ torch.cuda.empty_cache()
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ if use_half:
+ frames, flow_masks, masks_dilated = frames.half(), flow_masks.half(), masks_dilated.half()
+ gt_flows_bi = (gt_flows_bi[0].half(), gt_flows_bi[1].half())
+ self.fix_flow_complete = self.fix_flow_complete.half()
+ self.model = self.model.half()
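+ # Only the flow-completion and inpainting networks are cast to half precision here;
+ # RAFT above always runs in fp32 (see the '# use fp32 for RAFT' note).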
+
+ # ---- complete flow ----
+ flow_length = gt_flows_bi[0].size(1)
+ if flow_length > subvideo_length:
+ pred_flows_f, pred_flows_b = [], []
+ pad_len = 5
+ for f in range(0, flow_length, subvideo_length):
+ s_f = max(0, f - pad_len)
+ e_f = min(flow_length, f + subvideo_length + pad_len)
+ pad_len_s = max(0, f) - s_f
+ pad_len_e = e_f - min(flow_length, f + subvideo_length)
+ pred_flows_bi_sub, _ = self.fix_flow_complete.forward_bidirect_flow(
+ (gt_flows_bi[0][:, s_f:e_f], gt_flows_bi[1][:, s_f:e_f]),
+ flow_masks[:, s_f:e_f+1])
+ pred_flows_bi_sub = self.fix_flow_complete.combine_flow(
+ (gt_flows_bi[0][:, s_f:e_f], gt_flows_bi[1][:, s_f:e_f]),
+ pred_flows_bi_sub,
+ flow_masks[:, s_f:e_f+1])
+
+ pred_flows_f.append(pred_flows_bi_sub[0][:, pad_len_s:e_f-s_f-pad_len_e])
+ pred_flows_b.append(pred_flows_bi_sub[1][:, pad_len_s:e_f-s_f-pad_len_e])
+ torch.cuda.empty_cache()
+
+ pred_flows_f = torch.cat(pred_flows_f, dim=1)
+ pred_flows_b = torch.cat(pred_flows_b, dim=1)
+ pred_flows_bi = (pred_flows_f, pred_flows_b)
+ else:
+ pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow(gt_flows_bi, flow_masks)
+ pred_flows_bi = self.fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, flow_masks)
+ torch.cuda.empty_cache()
+ torch.cuda.empty_cache()
+ gc.collect()
+
+
+ masks_dilated_ori = masks_dilated.clone()
+ # ---- Pre-propagation ----
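+ # For long clips, first propagate over a sparse, uniformly sampled subset of frames so
+ # that content can travel across the whole video; the propagated samples are then written
+ # back into frames/masks_dilated before the dense, chunked propagation further below.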
+ subvideo_length_img_prop = min(100, subvideo_length) # cap the image-propagation sub-video length at 100 frames
+ if len(frames[0]) > subvideo_length_img_prop: # pre-propagate only when the clip is longer than subvideo_length_img_prop
+ sample_rate = len(frames[0])//(subvideo_length_img_prop//2)
+ index_sample = list(range(0, len(frames[0]), sample_rate))
+ sample_frames = torch.stack([frames[0][i].to(torch.float32) for i in index_sample]).unsqueeze(0) # use fp32 for RAFT
+ sample_masks_dilated = torch.stack([masks_dilated[0][i] for i in index_sample]).unsqueeze(0)
+ sample_flow_masks = torch.stack([flow_masks[0][i] for i in index_sample]).unsqueeze(0)
+
+ ## recompute flow for sampled frames
+ # use fp32 for RAFT
+ sample_video_length = sample_frames.size(1)
+ if sample_frames.size(1) > short_clip_len:
+ gt_flows_f_list, gt_flows_b_list = [], []
+ for f in range(0, sample_video_length, short_clip_len):
+ end_f = min(sample_video_length, f + short_clip_len)
+ if f == 0:
+ flows_f, flows_b = self.fix_raft(sample_frames[:,f:end_f], iters=raft_iter)
+ else:
+ flows_f, flows_b = self.fix_raft(sample_frames[:,f-1:end_f], iters=raft_iter)
+
+ gt_flows_f_list.append(flows_f)
+ gt_flows_b_list.append(flows_b)
+ torch.cuda.empty_cache()
+
+ gt_flows_f = torch.cat(gt_flows_f_list, dim=1)
+ gt_flows_b = torch.cat(gt_flows_b_list, dim=1)
+ sample_gt_flows_bi = (gt_flows_f, gt_flows_b)
+ else:
+ sample_gt_flows_bi = self.fix_raft(sample_frames, iters=raft_iter)
+ torch.cuda.empty_cache()
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ if use_half:
+ sample_frames, sample_flow_masks, sample_masks_dilated = sample_frames.half(), sample_flow_masks.half(), sample_masks_dilated.half()
+ sample_gt_flows_bi = (sample_gt_flows_bi[0].half(), sample_gt_flows_bi[1].half())
+
+ # ---- complete flow ----
+ flow_length = sample_gt_flows_bi[0].size(1)
+ if flow_length > subvideo_length:
+ pred_flows_f, pred_flows_b = [], []
+ pad_len = 5
+ for f in range(0, flow_length, subvideo_length):
+ s_f = max(0, f - pad_len)
+ e_f = min(flow_length, f + subvideo_length + pad_len)
+ pad_len_s = max(0, f) - s_f
+ pad_len_e = e_f - min(flow_length, f + subvideo_length)
+ pred_flows_bi_sub, _ = self.fix_flow_complete.forward_bidirect_flow(
+ (sample_gt_flows_bi[0][:, s_f:e_f], sample_gt_flows_bi[1][:, s_f:e_f]),
+ sample_flow_masks[:, s_f:e_f+1])
+ pred_flows_bi_sub = self.fix_flow_complete.combine_flow(
+ (sample_gt_flows_bi[0][:, s_f:e_f], sample_gt_flows_bi[1][:, s_f:e_f]),
+ pred_flows_bi_sub,
+ sample_flow_masks[:, s_f:e_f+1])
+
+ pred_flows_f.append(pred_flows_bi_sub[0][:, pad_len_s:e_f-s_f-pad_len_e])
+ pred_flows_b.append(pred_flows_bi_sub[1][:, pad_len_s:e_f-s_f-pad_len_e])
+ torch.cuda.empty_cache()
+
+ pred_flows_f = torch.cat(pred_flows_f, dim=1)
+ pred_flows_b = torch.cat(pred_flows_b, dim=1)
+ sample_pred_flows_bi = (pred_flows_f, pred_flows_b)
+ else:
+ sample_pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow(sample_gt_flows_bi, sample_flow_masks)
+ sample_pred_flows_bi = self.fix_flow_complete.combine_flow(sample_gt_flows_bi, sample_pred_flows_bi, sample_flow_masks)
+ torch.cuda.empty_cache()
+ torch.cuda.empty_cache()
+ gc.collect()
+
+ masked_frames = sample_frames * (1 - sample_masks_dilated)
+
+ if sample_video_length > subvideo_length_img_prop:
+ updated_frames, updated_masks = [], []
+ pad_len = 10
+ for f in range(0, sample_video_length, subvideo_length_img_prop):
+ s_f = max(0, f - pad_len)
+ e_f = min(sample_video_length, f + subvideo_length_img_prop + pad_len)
+ pad_len_s = max(0, f) - s_f
+ pad_len_e = e_f - min(sample_video_length, f + subvideo_length_img_prop)
+
+ b, t, _, _, _ = sample_masks_dilated[:, s_f:e_f].size()
+ pred_flows_bi_sub = (sample_pred_flows_bi[0][:, s_f:e_f-1], sample_pred_flows_bi[1][:, s_f:e_f-1])
+ prop_imgs_sub, updated_local_masks_sub = self.model.img_propagation(masked_frames[:, s_f:e_f],
+ pred_flows_bi_sub,
+ sample_masks_dilated[:, s_f:e_f],
+ 'nearest')
+ updated_frames_sub = sample_frames[:, s_f:e_f] * (1 - sample_masks_dilated[:, s_f:e_f]) + \
+ prop_imgs_sub.view(b, t, 3, h, w) * sample_masks_dilated[:, s_f:e_f]
+ updated_masks_sub = updated_local_masks_sub.view(b, t, 1, h, w)
+
+ updated_frames.append(updated_frames_sub[:, pad_len_s:e_f-s_f-pad_len_e])
+ updated_masks.append(updated_masks_sub[:, pad_len_s:e_f-s_f-pad_len_e])
+ torch.cuda.empty_cache()
+
+ updated_frames = torch.cat(updated_frames, dim=1)
+ updated_masks = torch.cat(updated_masks, dim=1)
+ else:
+ b, t, _, _, _ = sample_masks_dilated.size()
+ prop_imgs, updated_local_masks = self.model.img_propagation(masked_frames, sample_pred_flows_bi, sample_masks_dilated, 'nearest')
+ updated_frames = sample_frames * (1 - sample_masks_dilated) + prop_imgs.view(b, t, 3, h, w) * sample_masks_dilated
+ updated_masks = updated_local_masks.view(b, t, 1, h, w)
+ torch.cuda.empty_cache()
+
+ ## replace input frames/masks with updated frames/masks
+ for i, index in enumerate(index_sample):
+ frames[0][index] = updated_frames[0][i]
+ masks_dilated[0][index] = updated_masks[0][i]
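+ # The sampled positions now hold long-range propagated content and updated masks,
+ # which the dense frame-by-frame propagation below consumes as ordinary inputs.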
+
+
+ # ---- frame-by-frame image propagation ----
+ masked_frames = frames * (1 - masks_dilated)
+ subvideo_length_img_prop = min(100, subvideo_length) # cap the image-propagation sub-video length at 100 frames
+ if video_length > subvideo_length_img_prop:
+ updated_frames, updated_masks = [], []
+ pad_len = 10
+ for f in range(0, video_length, subvideo_length_img_prop):
+ s_f = max(0, f - pad_len)
+ e_f = min(video_length, f + subvideo_length_img_prop + pad_len)
+ pad_len_s = max(0, f) - s_f
+ pad_len_e = e_f - min(video_length, f + subvideo_length_img_prop)
+
+ b, t, _, _, _ = masks_dilated[:, s_f:e_f].size()
+ pred_flows_bi_sub = (pred_flows_bi[0][:, s_f:e_f-1], pred_flows_bi[1][:, s_f:e_f-1])
+ prop_imgs_sub, updated_local_masks_sub = self.model.img_propagation(masked_frames[:, s_f:e_f],
+ pred_flows_bi_sub,
+ masks_dilated[:, s_f:e_f],
+ 'nearest')
+ updated_frames_sub = frames[:, s_f:e_f] * (1 - masks_dilated[:, s_f:e_f]) + \
+ prop_imgs_sub.view(b, t, 3, h, w) * masks_dilated[:, s_f:e_f]
+ updated_masks_sub = updated_local_masks_sub.view(b, t, 1, h, w)
+
+ updated_frames.append(updated_frames_sub[:, pad_len_s:e_f-s_f-pad_len_e])
+ updated_masks.append(updated_masks_sub[:, pad_len_s:e_f-s_f-pad_len_e])
+ torch.cuda.empty_cache()
+
+ updated_frames = torch.cat(updated_frames, dim=1)
+ updated_masks = torch.cat(updated_masks, dim=1)
+ else:
+ b, t, _, _, _ = masks_dilated.size()
+ prop_imgs, updated_local_masks = self.model.img_propagation(masked_frames, pred_flows_bi, masks_dilated, 'nearest')
+ updated_frames = frames * (1 - masks_dilated) + prop_imgs.view(b, t, 3, h, w) * masks_dilated
+ updated_masks = updated_local_masks.view(b, t, 1, h, w)
+ torch.cuda.empty_cache()
+
+ comp_frames = [None] * video_length
+
+ neighbor_stride = neighbor_length // 2
+ if video_length > subvideo_length:
+ ref_num = subvideo_length // ref_stride
+ else:
+ ref_num = -1
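+ # For long videos, cap the number of non-local reference frames per window;
+ # ref_num = -1 lets get_ref_index use every ref_stride-th frame instead.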
+
+ torch.cuda.empty_cache()
+ # ---- feature propagation + transformer ----
+ for f in tqdm(range(0, video_length, neighbor_stride)):
+ neighbor_ids = [
+ i for i in range(max(0, f - neighbor_stride),
+ min(video_length, f + neighbor_stride + 1))
+ ]
+ ref_ids = get_ref_index(f, neighbor_ids, video_length, ref_stride, ref_num)
+ selected_imgs = updated_frames[:, neighbor_ids + ref_ids, :, :, :]
+ selected_masks = masks_dilated[:, neighbor_ids + ref_ids, :, :, :]
+ selected_update_masks = updated_masks[:, neighbor_ids + ref_ids, :, :, :]
+ selected_pred_flows_bi = (pred_flows_bi[0][:, neighbor_ids[:-1], :, :, :], pred_flows_bi[1][:, neighbor_ids[:-1], :, :, :])
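+ # Flows are defined between consecutive frames, so a window of N neighbor frames
+ # has N - 1 forward/backward flows (hence neighbor_ids[:-1]).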
+
+ with torch.no_grad():
+ # 1.0 indicates mask
+ l_t = len(neighbor_ids)
+
+ # pred_img = selected_imgs # results of image propagation
+ pred_img = self.model(selected_imgs, selected_pred_flows_bi, selected_masks, selected_update_masks, l_t)
+ pred_img = pred_img.view(-1, 3, h, w)
+
+ ## compose with input frames
+ pred_img = (pred_img + 1) / 2
+ pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255
+ binary_masks = masks_dilated_ori[0, neighbor_ids, :, :, :].cpu().permute(
+ 0, 2, 3, 1).numpy().astype(np.uint8) # use original mask
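+ # Compose predictions back onto the input: generated content is kept only inside the
+ # original dilated mask, and frames covered by two overlapping neighbor windows are
+ # blended 50/50 below to soften window seams.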
+ for i in range(len(neighbor_ids)):
+ idx = neighbor_ids[i]
+ img = np.array(pred_img[i]).astype(np.uint8) * binary_masks[i] \
+ + ori_frames_inp[idx] * (1 - binary_masks[i])
+ if comp_frames[idx] is None:
+ comp_frames[idx] = img
+ else:
+ comp_frames[idx] = comp_frames[idx].astype(np.float32) * 0.5 + img.astype(np.float32) * 0.5
+
+ comp_frames[idx] = comp_frames[idx].astype(np.uint8)
+
+ torch.cuda.empty_cache()
+
+ # ---- save composed video ----
+ comp_frames = [cv2.resize(f, out_size) for f in comp_frames]
+ writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"),
+ fps, (comp_frames[0].shape[1],comp_frames[0].shape[0]))
+ for f in range(video_length):
+ frame = comp_frames[f].astype(np.uint8)
+ writer.write(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+ writer.release()
+
+ torch.cuda.empty_cache()
+
+ return output_path
+
+
+
+if __name__ == '__main__':
+
+ device = get_device()
+ propainter_model_dir = "weights/propainter"
+ propainter = Propainter(propainter_model_dir, device=device)
+
+ video = "examples/example1/video.mp4"
+ mask = "examples/example1/mask.mp4"
+ output = "results/priori.mp4"
+ res = propainter.forward(video, mask, output)
+
+
+
\ No newline at end of file
diff --git a/propainter/model/__init__.py b/propainter/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/propainter/model/__init__.py
@@ -0,0 +1 @@
+
diff --git a/propainter/model/canny/canny_filter.py b/propainter/model/canny/canny_filter.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d16195c9355b506e22a2ba527006adb9c541a7c
--- /dev/null
+++ b/propainter/model/canny/canny_filter.py
@@ -0,0 +1,256 @@
+import math
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .gaussian import gaussian_blur2d
+from .kernels import get_canny_nms_kernel, get_hysteresis_kernel
+from .sobel import spatial_gradient
+
+def rgb_to_grayscale(image, rgb_weights=None):
+ if len(image.shape) < 3 or image.shape[-3] != 3:
+ raise ValueError(f"Input size must have a shape of (*, 3, H, W). Got {image.shape}")
+
+ if rgb_weights is None:
+ # 8 bit images
+ if image.dtype == torch.uint8:
+ rgb_weights = torch.tensor([76, 150, 29], device=image.device, dtype=torch.uint8)
+ # floating point images
+ elif image.dtype in (torch.float16, torch.float32, torch.float64):
+ rgb_weights = torch.tensor([0.299, 0.587, 0.114], device=image.device, dtype=image.dtype)
+ else:
+ raise TypeError(f"Unknown data type: {image.dtype}")
+ else:
+ # is tensor that we make sure is in the same device/dtype
+ rgb_weights = rgb_weights.to(image)
+
+ # unpack the color image channels with RGB order
+ r = image[..., 0:1, :, :]
+ g = image[..., 1:2, :, :]
+ b = image[..., 2:3, :, :]
+
+ w_r, w_g, w_b = rgb_weights.unbind()
+ return w_r * r + w_g * g + w_b * b
+
+
+def canny(
+ input: torch.Tensor,
+ low_threshold: float = 0.1,
+ high_threshold: float = 0.2,
+ kernel_size: Tuple[int, int] = (5, 5),
+ sigma: Tuple[float, float] = (1, 1),
+ hysteresis: bool = True,
+ eps: float = 1e-6,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ r"""Find edges of the input image and filters them using the Canny algorithm.
+
+ .. image:: _static/img/canny.png
+
+ Args:
+ input: input image tensor with shape :math:`(B,C,H,W)`.
+ low_threshold: lower threshold for the hysteresis procedure.
+ high_threshold: upper threshold for the hysteresis procedure.
+ kernel_size: the size of the kernel for the gaussian blur.
+ sigma: the standard deviation of the kernel for the gaussian blur.
+ hysteresis: if True, applies the hysteresis edge tracking.
+ Otherwise, the edges are divided between weak (0.5) and strong (1) edges.
+ eps: regularization number to avoid NaN during backprop.
+
+ Returns:
+ - the canny edge magnitudes map, shape of :math:`(B,1,H,W)`.
+ - the canny edge detection filtered by thresholds and hysteresis, shape of :math:`(B,1,H,W)`.
+
+ .. note::
+ See a working example `here `__.
+
+ Example:
+ >>> input = torch.rand(5, 3, 4, 4)
+ >>> magnitude, edges = canny(input) # 5x3x4x4
+ >>> magnitude.shape
+ torch.Size([5, 1, 4, 4])
+ >>> edges.shape
+ torch.Size([5, 1, 4, 4])
+ """
+ if not isinstance(input, torch.Tensor):
+ raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")
+
+ if not len(input.shape) == 4:
+ raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")
+
+ if low_threshold > high_threshold:
+ raise ValueError(
+ "Invalid input thresholds. low_threshold should be smaller than the high_threshold. Got: {}>{}".format(
+ low_threshold, high_threshold
+ )
+ )
+
+ if low_threshold < 0 or low_threshold > 1:
+ raise ValueError(f"Invalid input threshold. low_threshold should be in range (0,1). Got: {low_threshold}")
+
+ if high_threshold < 0 or high_threshold > 1:
+ raise ValueError(f"Invalid input threshold. high_threshold should be in range (0,1). Got: {high_threshold}")
+
+ device: torch.device = input.device
+ dtype: torch.dtype = input.dtype
+
+ # To Grayscale
+ if input.shape[1] == 3:
+ input = rgb_to_grayscale(input)
+
+ # Gaussian filter
+ blurred: torch.Tensor = gaussian_blur2d(input, kernel_size, sigma)
+
+ # Compute the gradients
+ gradients: torch.Tensor = spatial_gradient(blurred, normalized=False)
+
+ # Unpack the edges
+ gx: torch.Tensor = gradients[:, :, 0]
+ gy: torch.Tensor = gradients[:, :, 1]
+
+ # Compute gradient magnitude and angle
+ magnitude: torch.Tensor = torch.sqrt(gx * gx + gy * gy + eps)
+ angle: torch.Tensor = torch.atan2(gy, gx)
+
+ # Radians to Degrees
+ angle = 180.0 * angle / math.pi
+
+ # Round angle to the nearest 45 degree
+ angle = torch.round(angle / 45) * 45
+
+ # Non-maximal suppression
+ nms_kernels: torch.Tensor = get_canny_nms_kernel(device, dtype)
+ nms_magnitude: torch.Tensor = F.conv2d(magnitude, nms_kernels, padding=nms_kernels.shape[-1] // 2)
+
+ # Get the indices for both directions
+ positive_idx: torch.Tensor = (angle / 45) % 8
+ positive_idx = positive_idx.long()
+
+ negative_idx: torch.Tensor = ((angle / 45) + 4) % 8
+ negative_idx = negative_idx.long()
+
+ # Apply the non-maximum suppression to the different directions
+ channel_select_filtered_positive: torch.Tensor = torch.gather(nms_magnitude, 1, positive_idx)
+ channel_select_filtered_negative: torch.Tensor = torch.gather(nms_magnitude, 1, negative_idx)
+
+ channel_select_filtered: torch.Tensor = torch.stack(
+ [channel_select_filtered_positive, channel_select_filtered_negative], 1
+ )
+
+ is_max: torch.Tensor = channel_select_filtered.min(dim=1)[0] > 0.0
+
+ magnitude = magnitude * is_max
+
+ # Threshold
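+ # Double thresholding: pixels above high_threshold become strong edges (1.0),
+ # pixels between the two thresholds become weak edges (0.5), the rest are dropped.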
+ edges: torch.Tensor = F.threshold(magnitude, low_threshold, 0.0)
+
+ low: torch.Tensor = magnitude > low_threshold
+ high: torch.Tensor = magnitude > high_threshold
+
+ edges = low * 0.5 + high * 0.5
+ edges = edges.to(dtype)
+
+ # Hysteresis
+ if hysteresis:
+ edges_old: torch.Tensor = -torch.ones(edges.shape, device=edges.device, dtype=dtype)
+ hysteresis_kernels: torch.Tensor = get_hysteresis_kernel(device, dtype)
+
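+ # Iteratively promote weak edges (0.5) that touch a strong edge (1.0) until the
+ # edge map stops changing.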
+ while ((edges_old - edges).abs() != 0).any():
+ weak: torch.Tensor = (edges == 0.5).float()
+ strong: torch.Tensor = (edges == 1).float()
+
+ hysteresis_magnitude: torch.Tensor = F.conv2d(
+ edges, hysteresis_kernels, padding=hysteresis_kernels.shape[-1] // 2
+ )
+ hysteresis_magnitude = (hysteresis_magnitude == 1).any(1, keepdim=True).to(dtype)
+ hysteresis_magnitude = hysteresis_magnitude * weak + strong
+
+ edges_old = edges.clone()
+ edges = hysteresis_magnitude + (hysteresis_magnitude == 0) * weak * 0.5
+
+ edges = hysteresis_magnitude
+
+ return magnitude, edges
+
+
+class Canny(nn.Module):
+ r"""Module that finds edges of the input image and filters them using the Canny algorithm.
+
+ Args:
+ input: input image tensor with shape :math:`(B,C,H,W)`.
+ low_threshold: lower threshold for the hysteresis procedure.
+ high_threshold: upper threshold for the hysteresis procedure.
+ kernel_size: the size of the kernel for the gaussian blur.
+ sigma: the standard deviation of the kernel for the gaussian blur.
+ hysteresis: if True, applies the hysteresis edge tracking.
+ Otherwise, the edges are divided between weak (0.5) and strong (1) edges.
+ eps: regularization number to avoid NaN during backprop.
+
+ Returns:
+ - the canny edge magnitudes map, shape of :math:`(B,1,H,W)`.
+ - the canny edge detection filtered by thresholds and hysteresis, shape of :math:`(B,1,H,W)`.
+
+ Example:
+ >>> input = torch.rand(5, 3, 4, 4)
+ >>> magnitude, edges = Canny()(input) # 5x3x4x4
+ >>> magnitude.shape
+ torch.Size([5, 1, 4, 4])
+ >>> edges.shape
+ torch.Size([5, 1, 4, 4])
+ """
+
+ def __init__(
+ self,
+ low_threshold: float = 0.1,
+ high_threshold: float = 0.2,
+ kernel_size: Tuple[int, int] = (5, 5),
+ sigma: Tuple[float, float] = (1, 1),
+ hysteresis: bool = True,
+ eps: float = 1e-6,
+ ) -> None:
+ super().__init__()
+
+ if low_threshold > high_threshold:
+ raise ValueError(
+ "Invalid input thresholds. low_threshold should be\
+ smaller than the high_threshold. Got: {}>{}".format(
+ low_threshold, high_threshold
+ )
+ )
+
+ if low_threshold < 0 or low_threshold > 1:
+ raise ValueError(f"Invalid input threshold. low_threshold should be in range (0,1). Got: {low_threshold}")
+
+ if high_threshold < 0 or high_threshold > 1:
+ raise ValueError(f"Invalid input threshold. high_threshold should be in range (0,1). Got: {high_threshold}")
+
+ # Gaussian blur parameters
+ self.kernel_size = kernel_size
+ self.sigma = sigma
+
+ # Double threshold
+ self.low_threshold = low_threshold
+ self.high_threshold = high_threshold
+
+ # Hysteresis
+ self.hysteresis = hysteresis
+
+ self.eps: float = eps
+
+ def __repr__(self) -> str:
+ return ''.join(
+ (
+ f'{type(self).__name__}(',
+ ', '.join(
+ f'{name}={getattr(self, name)}' for name in sorted(self.__dict__) if not name.startswith('_')
+ ),
+ ')',
+ )
+ )
+
+ def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ return canny(
+ input, self.low_threshold, self.high_threshold, self.kernel_size, self.sigma, self.hysteresis, self.eps
+ )
\ No newline at end of file
diff --git a/propainter/model/canny/filter.py b/propainter/model/canny/filter.py
new file mode 100644
index 0000000000000000000000000000000000000000..e39d44d67a067c56f994dc9a189f3cf98663bf68
--- /dev/null
+++ b/propainter/model/canny/filter.py
@@ -0,0 +1,288 @@
+from typing import List
+
+import torch
+import torch.nn.functional as F
+
+from .kernels import normalize_kernel2d
+
+
+def _compute_padding(kernel_size: List[int]) -> List[int]:
+ """Compute padding tuple."""
+ # 4 or 6 ints: (padding_left, padding_right, padding_top, padding_bottom)
+ # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
+ if len(kernel_size) < 2:
+ raise AssertionError(kernel_size)
+ computed = [k - 1 for k in kernel_size]
+
+ # for even kernels we need to do asymmetric padding :(
+ out_padding = 2 * len(kernel_size) * [0]
+
+ for i in range(len(kernel_size)):
+ computed_tmp = computed[-(i + 1)]
+
+ pad_front = computed_tmp // 2
+ pad_rear = computed_tmp - pad_front
+
+ out_padding[2 * i + 0] = pad_front
+ out_padding[2 * i + 1] = pad_rear
+
+ return out_padding
+
+
+def filter2d(
+ input: torch.Tensor,
+ kernel: torch.Tensor,
+ border_type: str = 'reflect',
+ normalized: bool = False,
+ padding: str = 'same',
+) -> torch.Tensor:
+ r"""Convolve a tensor with a 2d kernel.
+
+ The function applies a given kernel to a tensor. The kernel is applied
+ independently at each depth channel of the tensor. Before applying the
+ kernel, the function applies padding according to the specified mode so
+ that the output remains in the same shape.
+
+ Args:
+ input: the input tensor with shape of
+ :math:`(B, C, H, W)`.
+ kernel: the kernel to be convolved with the input
+ tensor. The kernel shape must be :math:`(1, kH, kW)` or :math:`(B, kH, kW)`.
+ border_type: the padding mode to be applied before convolving.
+ The expected modes are: ``'constant'``, ``'reflect'``,
+ ``'replicate'`` or ``'circular'``.
+ normalized: If True, kernel will be L1 normalized.
+ padding: This defines the type of padding.
+ 2 modes available ``'same'`` or ``'valid'``.
+
+ Return:
+ torch.Tensor: the convolved tensor of same size and numbers of channels
+ as the input with shape :math:`(B, C, H, W)`.
+
+ Example:
+ >>> input = torch.tensor([[[
+ ... [0., 0., 0., 0., 0.],
+ ... [0., 0., 0., 0., 0.],
+ ... [0., 0., 5., 0., 0.],
+ ... [0., 0., 0., 0., 0.],
+ ... [0., 0., 0., 0., 0.],]]])
+ >>> kernel = torch.ones(1, 3, 3)
+ >>> filter2d(input, kernel, padding='same')
+ tensor([[[[0., 0., 0., 0., 0.],
+ [0., 5., 5., 5., 0.],
+ [0., 5., 5., 5., 0.],
+ [0., 5., 5., 5., 0.],
+ [0., 0., 0., 0., 0.]]]])
+ """
+ if not isinstance(input, torch.Tensor):
+ raise TypeError(f"Input input is not torch.Tensor. Got {type(input)}")
+
+ if not isinstance(kernel, torch.Tensor):
+ raise TypeError(f"Input kernel is not torch.Tensor. Got {type(kernel)}")
+
+ if not isinstance(border_type, str):
+ raise TypeError(f"Input border_type is not string. Got {type(border_type)}")
+
+ if border_type not in ['constant', 'reflect', 'replicate', 'circular']:
+ raise ValueError(
+ f"Invalid border type, we expect 'constant', \
+ 'reflect', 'replicate', 'circular'. Got:{border_type}"
+ )
+
+ if not isinstance(padding, str):
+ raise TypeError(f"Input padding is not string. Got {type(padding)}")
+
+ if padding not in ['valid', 'same']:
+ raise ValueError(f"Invalid padding mode, we expect 'valid' or 'same'. Got: {padding}")
+
+ if not len(input.shape) == 4:
+ raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")
+
+ if (not len(kernel.shape) == 3) and not ((kernel.shape[0] == 0) or (kernel.shape[0] == input.shape[0])):
+ raise ValueError(f"Invalid kernel shape, we expect 1xHxW or BxHxW. Got: {kernel.shape}")
+
+ # prepare kernel
+ b, c, h, w = input.shape
+ tmp_kernel: torch.Tensor = kernel.unsqueeze(1).to(input)
+
+ if normalized:
+ tmp_kernel = normalize_kernel2d(tmp_kernel)
+
+ tmp_kernel = tmp_kernel.expand(-1, c, -1, -1)
+
+ height, width = tmp_kernel.shape[-2:]
+
+ # pad the input tensor
+ if padding == 'same':
+ padding_shape: List[int] = _compute_padding([height, width])
+ input = F.pad(input, padding_shape, mode=border_type)
+
+ # kernel and input tensor reshape to align element-wise or batch-wise params
+ tmp_kernel = tmp_kernel.reshape(-1, 1, height, width)
+ input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1))
+
+ # convolve the tensor with the kernel.
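+ # groups=tmp_kernel.size(0) gives a depthwise convolution: each channel is filtered
+ # with its own kernel slice, matching the per-channel behavior described in the docstring.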
+ output = F.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
+
+ if padding == 'same':
+ out = output.view(b, c, h, w)
+ else:
+ out = output.view(b, c, h - height + 1, w - width + 1)
+
+ return out
+
+
+def filter2d_separable(
+ input: torch.Tensor,
+ kernel_x: torch.Tensor,
+ kernel_y: torch.Tensor,
+ border_type: str = 'reflect',
+ normalized: bool = False,
+ padding: str = 'same',
+) -> torch.Tensor:
+ r"""Convolve a tensor with two 1d kernels, in x and y directions.
+
+ The function applies a given kernel to a tensor. The kernel is applied
+ independently at each depth channel of the tensor. Before applying the
+ kernel, the function applies padding according to the specified mode so
+ that the output remains in the same shape.
+
+ Args:
+ input: the input tensor with shape of
+ :math:`(B, C, H, W)`.
+ kernel_x: the kernel to be convolved with the input
+ tensor. The kernel shape must be :math:`(1, kW)` or :math:`(B, kW)`.
+ kernel_y: the kernel to be convolved with the input
+ tensor. The kernel shape must be :math:`(1, kH)` or :math:`(B, kH)`.
+ border_type: the padding mode to be applied before convolving.
+ The expected modes are: ``'constant'``, ``'reflect'``,
+ ``'replicate'`` or ``'circular'``.
+ normalized: If True, kernel will be L1 normalized.
+ padding: This defines the type of padding.
+ 2 modes available ``'same'`` or ``'valid'``.
+
+ Return:
+ torch.Tensor: the convolved tensor of same size and numbers of channels
+ as the input with shape :math:`(B, C, H, W)`.
+
+ Example:
+ >>> input = torch.tensor([[[
+ ... [0., 0., 0., 0., 0.],
+ ... [0., 0., 0., 0., 0.],
+ ... [0., 0., 5., 0., 0.],
+ ... [0., 0., 0., 0., 0.],
+ ... [0., 0., 0., 0., 0.],]]])
+ >>> kernel = torch.ones(1, 3)
+
+ >>> filter2d_separable(input, kernel, kernel, padding='same')
+ tensor([[[[0., 0., 0., 0., 0.],
+ [0., 5., 5., 5., 0.],
+ [0., 5., 5., 5., 0.],
+ [0., 5., 5., 5., 0.],
+ [0., 0., 0., 0., 0.]]]])
+ """
+ out_x = filter2d(input, kernel_x.unsqueeze(0), border_type, normalized, padding)
+ out = filter2d(out_x, kernel_y.unsqueeze(-1), border_type, normalized, padding)
+ return out
+
+
+def filter3d(
+ input: torch.Tensor, kernel: torch.Tensor, border_type: str = 'replicate', normalized: bool = False
+) -> torch.Tensor:
+ r"""Convolve a tensor with a 3d kernel.
+
+ The function applies a given kernel to a tensor. The kernel is applied
+ independently at each depth channel of the tensor. Before applying the
+ kernel, the function applies padding according to the specified mode so
+ that the output remains in the same shape.
+
+ Args:
+ input: the input tensor with shape of
+ :math:`(B, C, D, H, W)`.
+ kernel: the kernel to be convolved with the input
+ tensor. The kernel shape must be :math:`(1, kD, kH, kW)` or :math:`(B, kD, kH, kW)`.
+ border_type: the padding mode to be applied before convolving.
+ The expected modes are: ``'constant'``,
+ ``'replicate'`` or ``'circular'``.
+ normalized: If True, kernel will be L1 normalized.
+
+ Return:
+ the convolved tensor of same size and numbers of channels
+ as the input with shape :math:`(B, C, D, H, W)`.
+
+ Example:
+ >>> input = torch.tensor([[[
+ ... [[0., 0., 0., 0., 0.],
+ ... [0., 0., 0., 0., 0.],
+ ... [0., 0., 0., 0., 0.],
+ ... [0., 0., 0., 0., 0.],
+ ... [0., 0., 0., 0., 0.]],
+ ... [[0., 0., 0., 0., 0.],
+ ... [0., 0., 0., 0., 0.],
+ ... [0., 0., 5., 0., 0.],
+ ... [0., 0., 0., 0., 0.],
+ ... [0., 0., 0., 0., 0.]],
+ ... [[0., 0., 0., 0., 0.],
+ ... [0., 0., 0., 0., 0.],
+ ... [0., 0., 0., 0., 0.],
+ ... [0., 0., 0., 0., 0.],
+ ... [0., 0., 0., 0., 0.]]
+ ... ]]])
+ >>> kernel = torch.ones(1, 3, 3, 3)
+ >>> filter3d(input, kernel)
+ tensor([[[[[0., 0., 0., 0., 0.],
+ [0., 5., 5., 5., 0.],
+ [0., 5., 5., 5., 0.],
+ [0., 5., 5., 5., 0.],
+ [0., 0., 0., 0., 0.]],
+
+ [[0., 0., 0., 0., 0.],
+ [0., 5., 5., 5., 0.],
+ [0., 5., 5., 5., 0.],
+ [0., 5., 5., 5., 0.],
+ [0., 0., 0., 0., 0.]],
+
+ [[0., 0., 0., 0., 0.],
+ [0., 5., 5., 5., 0.],
+ [0., 5., 5., 5., 0.],
+ [0., 5., 5., 5., 0.],
+ [0., 0., 0., 0., 0.]]]]])
+ """
+ if not isinstance(input, torch.Tensor):
+ raise TypeError(f"Input border_type is not torch.Tensor. Got {type(input)}")
+
+ if not isinstance(kernel, torch.Tensor):
+ raise TypeError(f"Input border_type is not torch.Tensor. Got {type(kernel)}")
+
+ if not isinstance(border_type, str):
+ raise TypeError(f"Input border_type is not string. Got {type(kernel)}")
+
+ if not len(input.shape) == 5:
+ raise ValueError(f"Invalid input shape, we expect BxCxDxHxW. Got: {input.shape}")
+
+ if not len(kernel.shape) == 4 and kernel.shape[0] != 1:
+ raise ValueError(f"Invalid kernel shape, we expect 1xDxHxW. Got: {kernel.shape}")
+
+ # prepare kernel
+ b, c, d, h, w = input.shape
+ tmp_kernel: torch.Tensor = kernel.unsqueeze(1).to(input)
+
+ if normalized:
+ bk, dk, hk, wk = kernel.shape
+ tmp_kernel = normalize_kernel2d(tmp_kernel.view(bk, dk, hk * wk)).view_as(tmp_kernel)
+
+ tmp_kernel = tmp_kernel.expand(-1, c, -1, -1, -1)
+
+ # pad the input tensor
+ depth, height, width = tmp_kernel.shape[-3:]
+ padding_shape: List[int] = _compute_padding([depth, height, width])
+ input_pad: torch.Tensor = F.pad(input, padding_shape, mode=border_type)
+
+ # kernel and input tensor reshape to align element-wise or batch-wise params
+ tmp_kernel = tmp_kernel.reshape(-1, 1, depth, height, width)
+ input_pad = input_pad.view(-1, tmp_kernel.size(0), input_pad.size(-3), input_pad.size(-2), input_pad.size(-1))
+
+ # convolve the tensor with the kernel.
+ output = F.conv3d(input_pad, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1)
+
+ return output.view(b, c, d, h, w)
\ No newline at end of file
diff --git a/propainter/model/canny/gaussian.py b/propainter/model/canny/gaussian.py
new file mode 100644
index 0000000000000000000000000000000000000000..182f05c5d7d297d97b3dd008287e053493350bb6
--- /dev/null
+++ b/propainter/model/canny/gaussian.py
@@ -0,0 +1,116 @@
+from typing import Tuple
+
+import torch
+import torch.nn as nn
+
+from .filter import filter2d, filter2d_separable
+from .kernels import get_gaussian_kernel1d, get_gaussian_kernel2d
+
+
+def gaussian_blur2d(
+ input: torch.Tensor,
+ kernel_size: Tuple[int, int],
+ sigma: Tuple[float, float],
+ border_type: str = 'reflect',
+ separable: bool = True,
+) -> torch.Tensor:
+ r"""Create an operator that blurs a tensor using a Gaussian filter.
+
+ .. image:: _static/img/gaussian_blur2d.png
+
+ The operator smooths the given tensor with a gaussian kernel by convolving
+ it to each channel. It supports batched operation.
+
+ Arguments:
+ input: the input tensor with shape :math:`(B,C,H,W)`.
+ kernel_size: the size of the kernel.
+ sigma: the standard deviation of the kernel.
+ border_type: the padding mode to be applied before convolving.
+ The expected modes are: ``'constant'``, ``'reflect'``,
+ ``'replicate'`` or ``'circular'``. Default: ``'reflect'``.
+ separable: run as composition of two 1d-convolutions.
+
+ Returns:
+ the blurred tensor with shape :math:`(B, C, H, W)`.
+
+ .. note::
+ See a working example `here `__.
+
+ Examples:
+ >>> input = torch.rand(2, 4, 5, 5)
+ >>> output = gaussian_blur2d(input, (3, 3), (1.5, 1.5))
+ >>> output.shape
+ torch.Size([2, 4, 5, 5])
+ """
+ if separable:
+ kernel_x: torch.Tensor = get_gaussian_kernel1d(kernel_size[1], sigma[1])
+ kernel_y: torch.Tensor = get_gaussian_kernel1d(kernel_size[0], sigma[0])
+ out = filter2d_separable(input, kernel_x[None], kernel_y[None], border_type)
+ else:
+ kernel: torch.Tensor = get_gaussian_kernel2d(kernel_size, sigma)
+ out = filter2d(input, kernel[None], border_type)
+ return out
+
+
+class GaussianBlur2d(nn.Module):
+ r"""Create an operator that blurs a tensor using a Gaussian filter.
+
+ The operator smooths the given tensor with a gaussian kernel by convolving
+ it to each channel. It supports batched operation.
+
+ Arguments:
+ kernel_size: the size of the kernel.
+ sigma: the standard deviation of the kernel.
+ border_type: the padding mode to be applied before convolving.
+ The expected modes are: ``'constant'``, ``'reflect'``,
+ ``'replicate'`` or ``'circular'``. Default: ``'reflect'``.
+ separable: run as composition of two 1d-convolutions.
+
+ Returns:
+ the blurred tensor.
+
+ Shape:
+ - Input: :math:`(B, C, H, W)`
+ - Output: :math:`(B, C, H, W)`
+
+ Examples::
+
+ >>> input = torch.rand(2, 4, 5, 5)
+ >>> gauss = GaussianBlur2d((3, 3), (1.5, 1.5))
+ >>> output = gauss(input) # 2x4x5x5
+ >>> output.shape
+ torch.Size([2, 4, 5, 5])
+ """
+
+ def __init__(
+ self,
+ kernel_size: Tuple[int, int],
+ sigma: Tuple[float, float],
+ border_type: str = 'reflect',
+ separable: bool = True,
+ ) -> None:
+ super().__init__()
+ self.kernel_size: Tuple[int, int] = kernel_size
+ self.sigma: Tuple[float, float] = sigma
+ self.border_type = border_type
+ self.separable = separable
+
+ def __repr__(self) -> str:
+ return (
+ self.__class__.__name__
+ + '(kernel_size='
+ + str(self.kernel_size)
+ + ', '
+ + 'sigma='
+ + str(self.sigma)
+ + ', '
+ + 'border_type='
+ + self.border_type
+ + ', '
+ + 'separable='
+ + str(self.separable)
+ + ')'
+ )
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ return gaussian_blur2d(input, self.kernel_size, self.sigma, self.border_type, self.separable)
\ No newline at end of file
diff --git a/propainter/model/canny/kernels.py b/propainter/model/canny/kernels.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae1ee251b8363ba76c7b63c6925a1776c50b7f32
--- /dev/null
+++ b/propainter/model/canny/kernels.py
@@ -0,0 +1,690 @@
+import math
+from math import sqrt
+from typing import List, Optional, Tuple
+
+import torch
+
+
+def normalize_kernel2d(input: torch.Tensor) -> torch.Tensor:
+ r"""Normalize both derivative and smoothing kernel."""
+ if len(input.size()) < 2:
+ raise TypeError(f"input should be at least 2D tensor. Got {input.size()}")
+ norm: torch.Tensor = input.abs().sum(dim=-1).sum(dim=-1)
+ return input / (norm.unsqueeze(-1).unsqueeze(-1))
+
+
+def gaussian(window_size: int, sigma: float) -> torch.Tensor:
+ device, dtype = None, None
+ if isinstance(sigma, torch.Tensor):
+ device, dtype = sigma.device, sigma.dtype
+ x = torch.arange(window_size, device=device, dtype=dtype) - window_size // 2
+ if window_size % 2 == 0:
+ x = x + 0.5
+
+ gauss = torch.exp((-x.pow(2.0) / (2 * sigma**2)).float())
+ return gauss / gauss.sum()
+
+
+def gaussian_discrete_erf(window_size: int, sigma) -> torch.Tensor:
+ r"""Discrete Gaussian by interpolating the error function.
+
+ Adapted from:
+ https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py
+ """
+ device = sigma.device if isinstance(sigma, torch.Tensor) else None
+ sigma = torch.as_tensor(sigma, dtype=torch.float, device=device)
+ x = torch.arange(window_size).float() - window_size // 2
+ t = 0.70710678 / torch.abs(sigma)
+ gauss = 0.5 * ((t * (x + 0.5)).erf() - (t * (x - 0.5)).erf())
+ gauss = gauss.clamp(min=0)
+ return gauss / gauss.sum()
+
+
+def _modified_bessel_0(x: torch.Tensor) -> torch.Tensor:
+ r"""Adapted from:
+
+ https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py
+ """
+ if torch.abs(x) < 3.75:
+ y = (x / 3.75) * (x / 3.75)
+ return 1.0 + y * (
+ 3.5156229 + y * (3.0899424 + y * (1.2067492 + y * (0.2659732 + y * (0.360768e-1 + y * 0.45813e-2))))
+ )
+ ax = torch.abs(x)
+ y = 3.75 / ax
+ ans = 0.916281e-2 + y * (-0.2057706e-1 + y * (0.2635537e-1 + y * (-0.1647633e-1 + y * 0.392377e-2)))
+ coef = 0.39894228 + y * (0.1328592e-1 + y * (0.225319e-2 + y * (-0.157565e-2 + y * ans)))
+ return (torch.exp(ax) / torch.sqrt(ax)) * coef
+
+
+def _modified_bessel_1(x: torch.Tensor) -> torch.Tensor:
+ r"""adapted from:
+
+ https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py
+ """
+ if torch.abs(x) < 3.75:
+ y = (x / 3.75) * (x / 3.75)
+ ans = 0.51498869 + y * (0.15084934 + y * (0.2658733e-1 + y * (0.301532e-2 + y * 0.32411e-3)))
+ return torch.abs(x) * (0.5 + y * (0.87890594 + y * ans))
+ ax = torch.abs(x)
+ y = 3.75 / ax
+ ans = 0.2282967e-1 + y * (-0.2895312e-1 + y * (0.1787654e-1 - y * 0.420059e-2))
+ ans = 0.39894228 + y * (-0.3988024e-1 + y * (-0.362018e-2 + y * (0.163801e-2 + y * (-0.1031555e-1 + y * ans))))
+ ans = ans * torch.exp(ax) / torch.sqrt(ax)
+ return -ans if x < 0.0 else ans
+
+
+def _modified_bessel_i(n: int, x: torch.Tensor) -> torch.Tensor:
+ r"""adapted from:
+
+ https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py
+ """
+ if n < 2:
+ raise ValueError("n must be greater than 1.")
+ if x == 0.0:
+ return x
+ device = x.device
+ tox = 2.0 / torch.abs(x)
+ ans = torch.tensor(0.0, device=device)
+ bip = torch.tensor(0.0, device=device)
+ bi = torch.tensor(1.0, device=device)
+ m = int(2 * (n + int(sqrt(40.0 * n))))
+ for j in range(m, 0, -1):
+ bim = bip + float(j) * tox * bi
+ bip = bi
+ bi = bim
+ if abs(bi) > 1.0e10:
+ ans = ans * 1.0e-10
+ bi = bi * 1.0e-10
+ bip = bip * 1.0e-10
+ if j == n:
+ ans = bip
+ ans = ans * _modified_bessel_0(x) / bi
+ return -ans if x < 0.0 and (n % 2) == 1 else ans
+
+
+def gaussian_discrete(window_size, sigma) -> torch.Tensor:
+ r"""Discrete Gaussian kernel based on the modified Bessel functions.
+
+ Adapted from:
+ https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py
+ """
+ device = sigma.device if isinstance(sigma, torch.Tensor) else None
+ sigma = torch.as_tensor(sigma, dtype=torch.float, device=device)
+ sigma2 = sigma * sigma
+ tail = int(window_size // 2)
+ out_pos: List[Optional[torch.Tensor]] = [None] * (tail + 1)
+ out_pos[0] = _modified_bessel_0(sigma2)
+ out_pos[1] = _modified_bessel_1(sigma2)
+ for k in range(2, len(out_pos)):
+ out_pos[k] = _modified_bessel_i(k, sigma2)
+ out = out_pos[:0:-1]
+ out.extend(out_pos)
+ out = torch.stack(out) * torch.exp(sigma2) # type: ignore
+ return out / out.sum() # type: ignore
+
+
+def laplacian_1d(window_size) -> torch.Tensor:
+ r"""One could also use the Laplacian of Gaussian formula to design the filter."""
+
+ filter_1d = torch.ones(window_size)
+ filter_1d[window_size // 2] = 1 - window_size
+ laplacian_1d: torch.Tensor = filter_1d
+ return laplacian_1d
+
+
+def get_box_kernel2d(kernel_size: Tuple[int, int]) -> torch.Tensor:
+ r"""Utility function that returns a box filter."""
+ kx: float = float(kernel_size[0])
+ ky: float = float(kernel_size[1])
+ scale: torch.Tensor = torch.tensor(1.0) / torch.tensor([kx * ky])
+ tmp_kernel: torch.Tensor = torch.ones(1, kernel_size[0], kernel_size[1])
+ return scale.to(tmp_kernel.dtype) * tmp_kernel
+
+
+def get_binary_kernel2d(window_size: Tuple[int, int]) -> torch.Tensor:
+ r"""Create a binary kernel to extract the patches.
+
+ If the window size is HxW will create a (H*W)xHxW kernel.
+ """
+ window_range: int = window_size[0] * window_size[1]
+ kernel: torch.Tensor = torch.zeros(window_range, window_range)
+ for i in range(window_range):
+ kernel[i, i] += 1.0
+ return kernel.view(window_range, 1, window_size[0], window_size[1])
+
+
+def get_sobel_kernel_3x3() -> torch.Tensor:
+ """Utility function that returns a sobel kernel of 3x3."""
+ return torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]])
+
+
+def get_sobel_kernel_5x5_2nd_order() -> torch.Tensor:
+ """Utility function that returns a 2nd order sobel kernel of 5x5."""
+ return torch.tensor(
+ [
+ [-1.0, 0.0, 2.0, 0.0, -1.0],
+ [-4.0, 0.0, 8.0, 0.0, -4.0],
+ [-6.0, 0.0, 12.0, 0.0, -6.0],
+ [-4.0, 0.0, 8.0, 0.0, -4.0],
+ [-1.0, 0.0, 2.0, 0.0, -1.0],
+ ]
+ )
+
+
+def _get_sobel_kernel_5x5_2nd_order_xy() -> torch.Tensor:
+ """Utility function that returns a 2nd order sobel kernel of 5x5."""
+ return torch.tensor(
+ [
+ [-1.0, -2.0, 0.0, 2.0, 1.0],
+ [-2.0, -4.0, 0.0, 4.0, 2.0],
+ [0.0, 0.0, 0.0, 0.0, 0.0],
+ [2.0, 4.0, 0.0, -4.0, -2.0],
+ [1.0, 2.0, 0.0, -2.0, -1.0],
+ ]
+ )
+
+
+def get_diff_kernel_3x3() -> torch.Tensor:
+ """Utility function that returns a first order derivative kernel of 3x3."""
+ return torch.tensor([[-0.0, 0.0, 0.0], [-1.0, 0.0, 1.0], [-0.0, 0.0, 0.0]])
+
+
+def get_diff_kernel3d(device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor:
+ """Utility function that returns a first order derivative kernel of 3x3x3."""
+ kernel: torch.Tensor = torch.tensor(
+ [
+ [
+ [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
+ [[0.0, 0.0, 0.0], [-0.5, 0.0, 0.5], [0.0, 0.0, 0.0]],
+ [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
+ ],
+ [
+ [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
+ [[0.0, -0.5, 0.0], [0.0, 0.0, 0.0], [0.0, 0.5, 0.0]],
+ [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
+ ],
+ [
+ [[0.0, 0.0, 0.0], [0.0, -0.5, 0.0], [0.0, 0.0, 0.0]],
+ [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
+ [[0.0, 0.0, 0.0], [0.0, 0.5, 0.0], [0.0, 0.0, 0.0]],
+ ],
+ ],
+ device=device,
+ dtype=dtype,
+ )
+ return kernel.unsqueeze(1)
+
+
+def get_diff_kernel3d_2nd_order(device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor:
+ """Utility function that returns a first order derivative kernel of 3x3x3."""
+ kernel: torch.Tensor = torch.tensor(
+ [
+ [
+ [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
+ [[0.0, 0.0, 0.0], [1.0, -2.0, 1.0], [0.0, 0.0, 0.0]],
+ [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
+ ],
+ [
+ [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
+ [[0.0, 1.0, 0.0], [0.0, -2.0, 0.0], [0.0, 1.0, 0.0]],
+ [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
+ ],
+ [
+ [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
+ [[0.0, 0.0, 0.0], [0.0, -2.0, 0.0], [0.0, 0.0, 0.0]],
+ [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
+ ],
+ [
+ [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
+ [[1.0, 0.0, -1.0], [0.0, 0.0, 0.0], [-1.0, 0.0, 1.0]],
+ [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
+ ],
+ [
+ [[0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [0.0, -1.0, 0.0]],
+ [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
+ [[0.0, -1.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
+ ],
+ [
+ [[0.0, 0.0, 0.0], [1.0, 0.0, -1.0], [0.0, 0.0, 0.0]],
+ [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
+ [[0.0, 0.0, 0.0], [-1.0, 0.0, 1.0], [0.0, 0.0, 0.0]],
+ ],
+ ],
+ device=device,
+ dtype=dtype,
+ )
+ return kernel.unsqueeze(1)
+
+
+def get_sobel_kernel2d() -> torch.Tensor:
+ kernel_x: torch.Tensor = get_sobel_kernel_3x3()
+ kernel_y: torch.Tensor = kernel_x.transpose(0, 1)
+ return torch.stack([kernel_x, kernel_y])
+
+
+def get_diff_kernel2d() -> torch.Tensor:
+ kernel_x: torch.Tensor = get_diff_kernel_3x3()
+ kernel_y: torch.Tensor = kernel_x.transpose(0, 1)
+ return torch.stack([kernel_x, kernel_y])
+
+
+def get_sobel_kernel2d_2nd_order() -> torch.Tensor:
+ gxx: torch.Tensor = get_sobel_kernel_5x5_2nd_order()
+ gyy: torch.Tensor = gxx.transpose(0, 1)
+ gxy: torch.Tensor = _get_sobel_kernel_5x5_2nd_order_xy()
+ return torch.stack([gxx, gxy, gyy])
+
+
+def get_diff_kernel2d_2nd_order() -> torch.Tensor:
+ gxx: torch.Tensor = torch.tensor([[0.0, 0.0, 0.0], [1.0, -2.0, 1.0], [0.0, 0.0, 0.0]])
+ gyy: torch.Tensor = gxx.transpose(0, 1)
+ gxy: torch.Tensor = torch.tensor([[-1.0, 0.0, 1.0], [0.0, 0.0, 0.0], [1.0, 0.0, -1.0]])
+ return torch.stack([gxx, gxy, gyy])
+
+
+def get_spatial_gradient_kernel2d(mode: str, order: int) -> torch.Tensor:
+ r"""Function that returns kernel for 1st or 2nd order image gradients, using one of the following operators:
+
+ sobel, diff.
+ """
+ if mode not in ['sobel', 'diff']:
+ raise TypeError(
+ "mode should be either sobel\
+ or diff. Got {}".format(
+ mode
+ )
+ )
+ if order not in [1, 2]:
+ raise TypeError(
+ "order should be either 1 or 2\
+ Got {}".format(
+ order
+ )
+ )
+ if mode == 'sobel' and order == 1:
+ kernel: torch.Tensor = get_sobel_kernel2d()
+ elif mode == 'sobel' and order == 2:
+ kernel = get_sobel_kernel2d_2nd_order()
+ elif mode == 'diff' and order == 1:
+ kernel = get_diff_kernel2d()
+ elif mode == 'diff' and order == 2:
+ kernel = get_diff_kernel2d_2nd_order()
+ else:
+ raise NotImplementedError("")
+ return kernel
+
+
+def get_spatial_gradient_kernel3d(mode: str, order: int, device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor:
+ r"""Function that returns kernel for 1st or 2nd order scale pyramid gradients, using one of the following
+ operators: sobel, diff."""
+ if mode not in ['sobel', 'diff']:
+ raise TypeError(
+ "mode should be either sobel\
+ or diff. Got {}".format(
+ mode
+ )
+ )
+ if order not in [1, 2]:
+ raise TypeError(
+ "order should be either 1 or 2\
+ Got {}".format(
+ order
+ )
+ )
+ if mode == 'sobel':
+ raise NotImplementedError("Sobel kernel for 3d gradient is not implemented yet")
+ if mode == 'diff' and order == 1:
+ kernel = get_diff_kernel3d(device, dtype)
+ elif mode == 'diff' and order == 2:
+ kernel = get_diff_kernel3d_2nd_order(device, dtype)
+ else:
+ raise NotImplementedError("")
+ return kernel
+
+
+def get_gaussian_kernel1d(kernel_size: int, sigma: float, force_even: bool = False) -> torch.Tensor:
+ r"""Function that returns Gaussian filter coefficients.
+
+ Args:
+ kernel_size: filter size. It should be odd and positive.
+ sigma: gaussian standard deviation.
+ force_even: overrides requirement for odd kernel size.
+
+ Returns:
+ 1D tensor with gaussian filter coefficients.
+
+ Shape:
+ - Output: :math:`(\text{kernel_size})`
+
+ Examples:
+
+ >>> get_gaussian_kernel1d(3, 2.5)
+ tensor([0.3243, 0.3513, 0.3243])
+
+ >>> get_gaussian_kernel1d(5, 1.5)
+ tensor([0.1201, 0.2339, 0.2921, 0.2339, 0.1201])
+ """
+ if not isinstance(kernel_size, int) or ((kernel_size % 2 == 0) and not force_even) or (kernel_size <= 0):
+ raise TypeError("kernel_size must be an odd positive integer. " "Got {}".format(kernel_size))
+ window_1d: torch.Tensor = gaussian(kernel_size, sigma)
+ return window_1d
+
+
+def get_gaussian_discrete_kernel1d(kernel_size: int, sigma: float, force_even: bool = False) -> torch.Tensor:
+ r"""Function that returns Gaussian filter coefficients based on the modified Bessel functions. Adapted from:
+ https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py.
+
+ Args:
+ kernel_size: filter size. It should be odd and positive.
+ sigma: gaussian standard deviation.
+ force_even: overrides requirement for odd kernel size.
+
+ Returns:
+ 1D tensor with gaussian filter coefficients.
+
+ Shape:
+ - Output: :math:`(\text{kernel_size})`
+
+ Examples:
+
+ >>> get_gaussian_discrete_kernel1d(3, 2.5)
+ tensor([0.3235, 0.3531, 0.3235])
+
+ >>> get_gaussian_discrete_kernel1d(5, 1.5)
+ tensor([0.1096, 0.2323, 0.3161, 0.2323, 0.1096])
+ """
+ if not isinstance(kernel_size, int) or ((kernel_size % 2 == 0) and not force_even) or (kernel_size <= 0):
+ raise TypeError("kernel_size must be an odd positive integer. " "Got {}".format(kernel_size))
+ window_1d = gaussian_discrete(kernel_size, sigma)
+ return window_1d
+
+
+def get_gaussian_erf_kernel1d(kernel_size: int, sigma: float, force_even: bool = False) -> torch.Tensor:
+ r"""Function that returns Gaussian filter coefficients by interpolating the error function, adapted from:
+ https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py.
+
+ Args:
+ kernel_size: filter size. It should be odd and positive.
+ sigma: gaussian standard deviation.
+ force_even: overrides requirement for odd kernel size.
+
+ Returns:
+ 1D tensor with gaussian filter coefficients.
+
+ Shape:
+ - Output: :math:`(\text{kernel_size})`
+
+ Examples:
+
+ >>> get_gaussian_erf_kernel1d(3, 2.5)
+ tensor([0.3245, 0.3511, 0.3245])
+
+ >>> get_gaussian_erf_kernel1d(5, 1.5)
+ tensor([0.1226, 0.2331, 0.2887, 0.2331, 0.1226])
+ """
+ if not isinstance(kernel_size, int) or ((kernel_size % 2 == 0) and not force_even) or (kernel_size <= 0):
+ raise TypeError("kernel_size must be an odd positive integer. " "Got {}".format(kernel_size))
+ window_1d = gaussian_discrete_erf(kernel_size, sigma)
+ return window_1d
+
+
+def get_gaussian_kernel2d(
+ kernel_size: Tuple[int, int], sigma: Tuple[float, float], force_even: bool = False
+) -> torch.Tensor:
+ r"""Function that returns Gaussian filter matrix coefficients.
+
+ Args:
+ kernel_size: filter sizes in the x and y direction.
+ Sizes should be odd and positive.
+ sigma: gaussian standard deviation in the x and y
+ direction.
+ force_even: overrides requirement for odd kernel size.
+
+ Returns:
+ 2D tensor with gaussian filter matrix coefficients.
+
+ Shape:
+ - Output: :math:`(\text{kernel_size}_x, \text{kernel_size}_y)`
+
+ Examples:
+ >>> get_gaussian_kernel2d((3, 3), (1.5, 1.5))
+ tensor([[0.0947, 0.1183, 0.0947],
+ [0.1183, 0.1478, 0.1183],
+ [0.0947, 0.1183, 0.0947]])
+ >>> get_gaussian_kernel2d((3, 5), (1.5, 1.5))
+ tensor([[0.0370, 0.0720, 0.0899, 0.0720, 0.0370],
+ [0.0462, 0.0899, 0.1123, 0.0899, 0.0462],
+ [0.0370, 0.0720, 0.0899, 0.0720, 0.0370]])
+ """
+ if not isinstance(kernel_size, tuple) or len(kernel_size) != 2:
+ raise TypeError(f"kernel_size must be a tuple of length two. Got {kernel_size}")
+ if not isinstance(sigma, tuple) or len(sigma) != 2:
+ raise TypeError(f"sigma must be a tuple of length two. Got {sigma}")
+ ksize_x, ksize_y = kernel_size
+ sigma_x, sigma_y = sigma
+ kernel_x: torch.Tensor = get_gaussian_kernel1d(ksize_x, sigma_x, force_even)
+ kernel_y: torch.Tensor = get_gaussian_kernel1d(ksize_y, sigma_y, force_even)
+ kernel_2d: torch.Tensor = torch.matmul(kernel_x.unsqueeze(-1), kernel_y.unsqueeze(-1).t())
+ return kernel_2d
+
+
+def get_laplacian_kernel1d(kernel_size: int) -> torch.Tensor:
+ r"""Function that returns the coefficients of a 1D Laplacian filter.
+
+ Args:
+ kernel_size: filter size. It should be odd and positive.
+
+ Returns:
+ 1D tensor with laplacian filter coefficients.
+
+ Shape:
+ - Output: math:`(\text{kernel_size})`
+
+ Examples:
+ >>> get_laplacian_kernel1d(3)
+ tensor([ 1., -2., 1.])
+ >>> get_laplacian_kernel1d(5)
+ tensor([ 1., 1., -4., 1., 1.])
+ """
+ if not isinstance(kernel_size, int) or kernel_size % 2 == 0 or kernel_size <= 0:
+ raise TypeError(f"ksize must be an odd positive integer. Got {kernel_size}")
+ window_1d: torch.Tensor = laplacian_1d(kernel_size)
+ return window_1d
+
+
+def get_laplacian_kernel2d(kernel_size: int) -> torch.Tensor:
+ r"""Function that returns Gaussian filter matrix coefficients.
+
+ Args:
+ kernel_size: filter size should be odd.
+
+ Returns:
+ 2D tensor with laplacian filter matrix coefficients.
+
+ Shape:
+ - Output: :math:`(\text{kernel_size}_x, \text{kernel_size}_y)`
+
+ Examples:
+ >>> get_laplacian_kernel2d(3)
+ tensor([[ 1., 1., 1.],
+ [ 1., -8., 1.],
+ [ 1., 1., 1.]])
+ >>> get_laplacian_kernel2d(5)
+ tensor([[ 1., 1., 1., 1., 1.],
+ [ 1., 1., 1., 1., 1.],
+ [ 1., 1., -24., 1., 1.],
+ [ 1., 1., 1., 1., 1.],
+ [ 1., 1., 1., 1., 1.]])
+ """
+ if not isinstance(kernel_size, int) or kernel_size % 2 == 0 or kernel_size <= 0:
+ raise TypeError(f"ksize must be an odd positive integer. Got {kernel_size}")
+
+ kernel = torch.ones((kernel_size, kernel_size))
+ mid = kernel_size // 2
+ kernel[mid, mid] = 1 - kernel_size**2
+ kernel_2d: torch.Tensor = kernel
+ return kernel_2d
+
+
+def get_pascal_kernel_2d(kernel_size: int, norm: bool = True) -> torch.Tensor:
+ """Generate pascal filter kernel by kernel size.
+
+ Args:
+ kernel_size: height and width of the kernel.
+ norm: if to normalize the kernel or not. Default: True.
+
+ Returns:
+ kernel shaped as :math:`(kernel_size, kernel_size)`
+
+ Examples:
+ >>> get_pascal_kernel_2d(1)
+ tensor([[1.]])
+ >>> get_pascal_kernel_2d(4)
+ tensor([[0.0156, 0.0469, 0.0469, 0.0156],
+ [0.0469, 0.1406, 0.1406, 0.0469],
+ [0.0469, 0.1406, 0.1406, 0.0469],
+ [0.0156, 0.0469, 0.0469, 0.0156]])
+ >>> get_pascal_kernel_2d(4, norm=False)
+ tensor([[1., 3., 3., 1.],
+ [3., 9., 9., 3.],
+ [3., 9., 9., 3.],
+ [1., 3., 3., 1.]])
+ """
+ a = get_pascal_kernel_1d(kernel_size)
+
+ filt = a[:, None] * a[None, :]
+ if norm:
+ filt = filt / torch.sum(filt)
+ return filt
+
+
+def get_pascal_kernel_1d(kernel_size: int, norm: bool = False) -> torch.Tensor:
+ """Generate Yang Hui triangle (Pascal's triangle) by a given number.
+
+ Args:
+ kernel_size: height and width of the kernel.
+ norm: if to normalize the kernel or not. Default: False.
+
+ Returns:
+ kernel shaped as :math:`(kernel_size,)`
+
+ Examples:
+ >>> get_pascal_kernel_1d(1)
+ tensor([1.])
+ >>> get_pascal_kernel_1d(2)
+ tensor([1., 1.])
+ >>> get_pascal_kernel_1d(3)
+ tensor([1., 2., 1.])
+ >>> get_pascal_kernel_1d(4)
+ tensor([1., 3., 3., 1.])
+ >>> get_pascal_kernel_1d(5)
+ tensor([1., 4., 6., 4., 1.])
+ >>> get_pascal_kernel_1d(6)
+ tensor([ 1., 5., 10., 10., 5., 1.])
+ """
+ pre: List[float] = []
+ cur: List[float] = []
+ for i in range(kernel_size):
+ cur = [1.0] * (i + 1)
+
+ for j in range(1, i // 2 + 1):
+ value = pre[j - 1] + pre[j]
+ cur[j] = value
+ if i != 2 * j:
+ cur[-j - 1] = value
+ pre = cur
+
+ out = torch.as_tensor(cur)
+ if norm:
+ out = out / torch.sum(out)
+ return out
+
+
+def get_canny_nms_kernel(device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor:
+ """Utility function that returns 3x3 kernels for the Canny Non-maximal suppression."""
+ kernel: torch.Tensor = torch.tensor(
+ [
+ [[0.0, 0.0, 0.0], [0.0, 1.0, -1.0], [0.0, 0.0, 0.0]],
+ [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]],
+ [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, -1.0, 0.0]],
+ [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]],
+ [[0.0, 0.0, 0.0], [-1.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
+ [[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
+ [[0.0, -1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
+ [[0.0, 0.0, -1.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]],
+ ],
+ device=device,
+ dtype=dtype,
+ )
+ return kernel.unsqueeze(1)
+
+
+def get_hysteresis_kernel(device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor:
+ """Utility function that returns the 3x3 kernels for the Canny hysteresis."""
+ kernel: torch.Tensor = torch.tensor(
+ [
+ [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 0.0]],
+ [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 1.0]],
+ [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
+ [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [1.0, 0.0, 0.0]],
+ [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
+ [[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
+ [[0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
+ [[0.0, 0.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
+ ],
+ device=device,
+ dtype=dtype,
+ )
+ return kernel.unsqueeze(1)
+
+
+def get_hanning_kernel1d(kernel_size: int, device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor:
+ r"""Returns Hanning (also known as Hann) kernel, used in signal processing and KCF tracker.
+
+ .. math:: w(n) = 0.5 - 0.5\cos\left(\frac{2\pi n}{M-1}\right)
+ \qquad 0 \leq n \leq M-1
+
+ See further in numpy docs https://numpy.org/doc/stable/reference/generated/numpy.hanning.html
+
+ Args:
+ kernel_size: The size of the kernel. It should be positive.
+
+ Returns:
+ 1D tensor with Hanning filter coefficients.
+ .. math:: w(n) = 0.5 - 0.5\cos\left(\frac{2\pi n}{M-1}\right)
+
+ Shape:
+ - Output: :math:`(\text{kernel\_size})`
+
+ Examples:
+ >>> get_hanning_kernel1d(4)
+ tensor([0.0000, 0.7500, 0.7500, 0.0000])
+ """
+ if not isinstance(kernel_size, int) or kernel_size <= 2:
+ raise TypeError(f"ksize must be an positive integer > 2. Got {kernel_size}")
+
+ x: torch.Tensor = torch.arange(kernel_size, device=device, dtype=dtype)
+ x = 0.5 - 0.5 * torch.cos(2.0 * math.pi * x / float(kernel_size - 1))
+ return x
+
+
+def get_hanning_kernel2d(kernel_size: Tuple[int, int], device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor:
+ r"""Returns 2d Hanning kernel, used in signal processing and KCF tracker.
+
+ Args:
+ kernel_size: The size of the kernel for the filter. It should be positive.
+
+ Returns:
+ 2D tensor with Hanning filter coefficients.
+ .. math:: w(n) = 0.5 - 0.5\cos\left(\frac{2\pi n}{M-1}\right)
+
+ Shape:
+ - Output: :math:`(\text{kernel\_size[0]}, \text{kernel\_size[1]})`
+ """
+ if kernel_size[0] <= 2 or kernel_size[1] <= 2:
+ raise TypeError(f"ksize must be an tuple of positive integers > 2. Got {kernel_size}")
+ ky: torch.Tensor = get_hanning_kernel1d(kernel_size[0], device, dtype)[None].T
+ kx: torch.Tensor = get_hanning_kernel1d(kernel_size[1], device, dtype)[None]
+ kernel2d = ky @ kx
+ return kernel2d
\ No newline at end of file
diff --git a/propainter/model/canny/sobel.py b/propainter/model/canny/sobel.py
new file mode 100644
index 0000000000000000000000000000000000000000..d780c5c4a22bb6403122a292b6d30fa022f262e8
--- /dev/null
+++ b/propainter/model/canny/sobel.py
@@ -0,0 +1,263 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .kernels import get_spatial_gradient_kernel2d, get_spatial_gradient_kernel3d, normalize_kernel2d
+
+
+def spatial_gradient(input: torch.Tensor, mode: str = 'sobel', order: int = 1, normalized: bool = True) -> torch.Tensor:
+ r"""Compute the first order image derivative in both x and y using a Sobel operator.
+
+ .. image:: _static/img/spatial_gradient.png
+
+ Args:
+ input: input image tensor with shape :math:`(B, C, H, W)`.
+ mode: derivatives modality, can be: `sobel` or `diff`.
+ order: the order of the derivatives.
+ normalized: whether the output is normalized.
+
+ Return:
+ the derivatives of the input feature map. with shape :math:`(B, C, 2, H, W)`.
+
+ .. note::
+ See a working example `here `__.
+
+ Examples:
+ >>> input = torch.rand(1, 3, 4, 4)
+ >>> output = spatial_gradient(input) # 1x3x2x4x4
+ >>> output.shape
+ torch.Size([1, 3, 2, 4, 4])
+ """
+ if not isinstance(input, torch.Tensor):
+ raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")
+
+ if not len(input.shape) == 4:
+ raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")
+ # allocate kernel
+ kernel: torch.Tensor = get_spatial_gradient_kernel2d(mode, order)
+ if normalized:
+ kernel = normalize_kernel2d(kernel)
+
+ # prepare kernel
+ b, c, h, w = input.shape
+ tmp_kernel: torch.Tensor = kernel.to(input).detach()
+ tmp_kernel = tmp_kernel.unsqueeze(1).unsqueeze(1)
+
+ # convolve input tensor with sobel kernel
+ kernel_flip: torch.Tensor = tmp_kernel.flip(-3)
+
+ # Pad with "replicate" for spatial dims, but with zeros for channel
+ spatial_pad = [kernel.size(1) // 2, kernel.size(1) // 2, kernel.size(2) // 2, kernel.size(2) // 2]
+ out_channels: int = 3 if order == 2 else 2
+ padded_inp: torch.Tensor = F.pad(input.reshape(b * c, 1, h, w), spatial_pad, 'replicate')[:, :, None]
+
+ return F.conv3d(padded_inp, kernel_flip, padding=0).view(b, c, out_channels, h, w)
+
+
+def spatial_gradient3d(input: torch.Tensor, mode: str = 'diff', order: int = 1) -> torch.Tensor:
+ r"""Compute the first and second order volume derivative in x, y and d using a diff operator.
+
+ Args:
+ input: input features tensor with shape :math:`(B, C, D, H, W)`.
+ mode: derivatives modality, can be: `sobel` or `diff`.
+ order: the order of the derivatives.
+
+ Return:
+ the spatial gradients of the input feature map with shape :math:`(B, C, 3, D, H, W)`
+ or :math:`(B, C, 6, D, H, W)`.
+
+ Examples:
+ >>> input = torch.rand(1, 4, 2, 4, 4)
+ >>> output = spatial_gradient3d(input)
+ >>> output.shape
+ torch.Size([1, 4, 3, 2, 4, 4])
+ """
+ if not isinstance(input, torch.Tensor):
+ raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")
+
+ if not len(input.shape) == 5:
+ raise ValueError(f"Invalid input shape, we expect BxCxDxHxW. Got: {input.shape}")
+ b, c, d, h, w = input.shape
+ dev = input.device
+ dtype = input.dtype
+ if (mode == 'diff') and (order == 1):
+ # we go for the special case implementation due to conv3d bad speed
+ x: torch.Tensor = F.pad(input, 6 * [1], 'replicate')
+ center = slice(1, -1)
+ left = slice(0, -2)
+ right = slice(2, None)
+ out = torch.empty(b, c, 3, d, h, w, device=dev, dtype=dtype)
+ out[..., 0, :, :, :] = x[..., center, center, right] - x[..., center, center, left]
+ out[..., 1, :, :, :] = x[..., center, right, center] - x[..., center, left, center]
+ out[..., 2, :, :, :] = x[..., right, center, center] - x[..., left, center, center]
+ out = 0.5 * out
+ else:
+ # allocate kernel
+ kernel: torch.Tensor = get_spatial_gradient_kernel3d(mode, order)
+
+ tmp_kernel: torch.Tensor = kernel.to(input).detach()
+ tmp_kernel = tmp_kernel.repeat(c, 1, 1, 1, 1)
+
+ # convolve input tensor with grad kernel
+ kernel_flip: torch.Tensor = tmp_kernel.flip(-3)
+
+ # Pad with "replicate" for spatial dims, but with zeros for channel
+ spatial_pad = [
+ kernel.size(2) // 2,
+ kernel.size(2) // 2,
+ kernel.size(3) // 2,
+ kernel.size(3) // 2,
+ kernel.size(4) // 2,
+ kernel.size(4) // 2,
+ ]
+ out_ch: int = 6 if order == 2 else 3
+ out = F.conv3d(F.pad(input, spatial_pad, 'replicate'), kernel_flip, padding=0, groups=c).view(
+ b, c, out_ch, d, h, w
+ )
+ return out
+
+
+def sobel(input: torch.Tensor, normalized: bool = True, eps: float = 1e-6) -> torch.Tensor:
+ r"""Compute the Sobel operator and returns the magnitude per channel.
+
+ .. image:: _static/img/sobel.png
+
+ Args:
+ input: the input image with shape :math:`(B,C,H,W)`.
+ normalized: if True, L1 norm of the kernel is set to 1.
+ eps: regularization number to avoid NaN during backprop.
+
+ Return:
+ the sobel edge gradient magnitudes map with shape :math:`(B,C,H,W)`.
+
+ .. note::
+ See a working example `here `__.
+
+ Example:
+ >>> input = torch.rand(1, 3, 4, 4)
+ >>> output = sobel(input) # 1x3x4x4
+ >>> output.shape
+ torch.Size([1, 3, 4, 4])
+ """
+ if not isinstance(input, torch.Tensor):
+ raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}")
+
+ if not len(input.shape) == 4:
+ raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}")
+
+ # compute the x/y gradients
+ edges: torch.Tensor = spatial_gradient(input, normalized=normalized)
+
+ # unpack the edges
+ gx: torch.Tensor = edges[:, :, 0]
+ gy: torch.Tensor = edges[:, :, 1]
+
+ # compute gradient magnitude
+ magnitude: torch.Tensor = torch.sqrt(gx * gx + gy * gy + eps)
+
+ return magnitude
+
+
+class SpatialGradient(nn.Module):
+ r"""Compute the first order image derivative in both x and y using a Sobel operator.
+
+ Args:
+ mode: derivatives modality, can be: `sobel` or `diff`.
+ order: the order of the derivatives.
+ normalized: whether the output is normalized.
+
+ Return:
+ the sobel edges of the input feature map.
+
+ Shape:
+ - Input: :math:`(B, C, H, W)`
+ - Output: :math:`(B, C, 2, H, W)`
+
+ Examples:
+ >>> input = torch.rand(1, 3, 4, 4)
+ >>> output = SpatialGradient()(input) # 1x3x2x4x4
+ """
+
+ def __init__(self, mode: str = 'sobel', order: int = 1, normalized: bool = True) -> None:
+ super().__init__()
+ self.normalized: bool = normalized
+ self.order: int = order
+ self.mode: str = mode
+
+ def __repr__(self) -> str:
+ return (
+ self.__class__.__name__ + '('
+ 'order=' + str(self.order) + ', ' + 'normalized=' + str(self.normalized) + ', ' + 'mode=' + self.mode + ')'
+ )
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ return spatial_gradient(input, self.mode, self.order, self.normalized)
+
+
+class SpatialGradient3d(nn.Module):
+ r"""Compute the first and second order volume derivative in x, y and d using a diff operator.
+
+ Args:
+ mode: derivatives modality, can be: `sobel` or `diff`.
+ order: the order of the derivatives.
+
+ Return:
+ the spatial gradients of the input feature map.
+
+ Shape:
+ - Input: :math:`(B, C, D, H, W)`. D, H, W are spatial dimensions, gradient is calculated w.r.t to them.
+ - Output: :math:`(B, C, 3, D, H, W)` or :math:`(B, C, 6, D, H, W)`
+
+ Examples:
+ >>> input = torch.rand(1, 4, 2, 4, 4)
+ >>> output = SpatialGradient3d()(input)
+ >>> output.shape
+ torch.Size([1, 4, 3, 2, 4, 4])
+ """
+
+ def __init__(self, mode: str = 'diff', order: int = 1) -> None:
+ super().__init__()
+ self.order: int = order
+ self.mode: str = mode
+ self.kernel = get_spatial_gradient_kernel3d(mode, order)
+ return
+
+ def __repr__(self) -> str:
+ return self.__class__.__name__ + '(' 'order=' + str(self.order) + ', ' + 'mode=' + self.mode + ')'
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor: # type: ignore
+ return spatial_gradient3d(input, self.mode, self.order)
+
+
+class Sobel(nn.Module):
+ r"""Compute the Sobel operator and returns the magnitude per channel.
+
+ Args:
+ normalized: if True, L1 norm of the kernel is set to 1.
+ eps: regularization number to avoid NaN during backprop.
+
+ Return:
+ the sobel edge gradient magnitudes map.
+
+ Shape:
+ - Input: :math:`(B, C, H, W)`
+ - Output: :math:`(B, C, H, W)`
+
+ Examples:
+ >>> input = torch.rand(1, 3, 4, 4)
+ >>> output = Sobel()(input) # 1x3x4x4
+ """
+
+ def __init__(self, normalized: bool = True, eps: float = 1e-6) -> None:
+ super().__init__()
+ self.normalized: bool = normalized
+ self.eps: float = eps
+
+ def __repr__(self) -> str:
+ return self.__class__.__name__ + '(' 'normalized=' + str(self.normalized) + ')'
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ return sobel(input, self.normalized, self.eps)
\ No newline at end of file
diff --git a/propainter/model/misc.py b/propainter/model/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..43b849902245dd338a36f4f4ff09e33425365af6
--- /dev/null
+++ b/propainter/model/misc.py
@@ -0,0 +1,131 @@
+import os
+import re
+import random
+import time
+import torch
+import torch.nn as nn
+import logging
+import numpy as np
+from os import path as osp
+
+def constant_init(module, val, bias=0):
+ if hasattr(module, 'weight') and module.weight is not None:
+ nn.init.constant_(module.weight, val)
+ if hasattr(module, 'bias') and module.bias is not None:
+ nn.init.constant_(module.bias, bias)
+
+initialized_logger = {}
+def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
+ """Get the root logger.
+ The logger will be initialized if it has not been initialized. By default a
+ StreamHandler will be added. If `log_file` is specified, a FileHandler will
+ also be added.
+ Args:
+ logger_name (str): root logger name. Default: 'basicsr'.
+ log_file (str | None): The log filename. If specified, a FileHandler
+ will be added to the root logger.
+ log_level (int): The root logger level. Note that only the process of
+ rank 0 is affected, while other processes will set the level to
+ "Error" and be silent most of the time.
+ Returns:
+ logging.Logger: The root logger.
+ """
+ logger = logging.getLogger(logger_name)
+ # if the logger has been initialized, just return it
+ if logger_name in initialized_logger:
+ return logger
+
+ format_str = '%(asctime)s %(levelname)s: %(message)s'
+ stream_handler = logging.StreamHandler()
+ stream_handler.setFormatter(logging.Formatter(format_str))
+ logger.addHandler(stream_handler)
+ logger.propagate = False
+
+ if log_file is not None:
+ logger.setLevel(log_level)
+ # add file handler
+ # file_handler = logging.FileHandler(log_file, 'w')
+ file_handler = logging.FileHandler(log_file, 'a') #Shangchen: keep the previous log
+ file_handler.setFormatter(logging.Formatter(format_str))
+ file_handler.setLevel(log_level)
+ logger.addHandler(file_handler)
+ initialized_logger[logger_name] = True
+ return logger
+
+
+IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\
+ torch.__version__)[0][:3])] >= [1, 12, 0]
+
+def gpu_is_available():
+ if IS_HIGH_VERSION:
+ if torch.backends.mps.is_available():
+ return True
+ return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False
+
+def get_device(gpu_id=None):
+ if gpu_id is None:
+ gpu_str = ''
+ elif isinstance(gpu_id, int):
+ gpu_str = f':{gpu_id}'
+ else:
+ raise TypeError('Input should be an int value.')
+
+ if IS_HIGH_VERSION:
+ if torch.backends.mps.is_available():
+ return torch.device('mps'+gpu_str)
+ return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu')
+
+
+def set_random_seed(seed):
+ """Set random seeds."""
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+
+def get_time_str():
+ return time.strftime('%Y%m%d_%H%M%S', time.localtime())
+
+
+def scandir(dir_path, suffix=None, recursive=False, full_path=False):
+ """Scan a directory to find the interested files.
+
+ Args:
+ dir_path (str): Path of the directory.
+ suffix (str | tuple(str), optional): File suffix that we are
+ interested in. Default: None.
+ recursive (bool, optional): If set to True, recursively scan the
+ directory. Default: False.
+ full_path (bool, optional): If set to True, include the dir_path.
+ Default: False.
+
+ Returns:
+ A generator for all the matched files with relative paths.
+ """
+
+ if (suffix is not None) and not isinstance(suffix, (str, tuple)):
+ raise TypeError('"suffix" must be a string or tuple of strings')
+
+ root = dir_path
+
+ def _scandir(dir_path, suffix, recursive):
+ for entry in os.scandir(dir_path):
+ if not entry.name.startswith('.') and entry.is_file():
+ if full_path:
+ return_path = entry.path
+ else:
+ return_path = osp.relpath(entry.path, root)
+
+ if suffix is None:
+ yield return_path
+ elif return_path.endswith(suffix):
+ yield return_path
+ else:
+ if recursive:
+ yield from _scandir(entry.path, suffix=suffix, recursive=recursive)
+ else:
+ continue
+
+ return _scandir(dir_path, suffix=suffix, recursive=recursive)
\ No newline at end of file
diff --git a/propainter/model/modules/base_module.py b/propainter/model/modules/base_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..b28c094308dd4d1bbb62dd75e02e937e2c9ddf14
--- /dev/null
+++ b/propainter/model/modules/base_module.py
@@ -0,0 +1,131 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from functools import reduce
+
+class BaseNetwork(nn.Module):
+ def __init__(self):
+ super(BaseNetwork, self).__init__()
+
+ def print_network(self):
+ if isinstance(self, list):
+ self = self[0]
+ num_params = 0
+ for param in self.parameters():
+ num_params += param.numel()
+ print(
+ 'Network [%s] was created. Total number of parameters: %.1f million. '
+ 'To see the architecture, do print(network).' %
+ (type(self).__name__, num_params / 1000000))
+
+ def init_weights(self, init_type='normal', gain=0.02):
+ '''
+ initialize network's weights
+ init_type: normal | xavier | kaiming | orthogonal
+ https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39
+ '''
+ def init_func(m):
+ classname = m.__class__.__name__
+ if classname.find('InstanceNorm2d') != -1:
+ if hasattr(m, 'weight') and m.weight is not None:
+ nn.init.constant_(m.weight.data, 1.0)
+ if hasattr(m, 'bias') and m.bias is not None:
+ nn.init.constant_(m.bias.data, 0.0)
+ elif hasattr(m, 'weight') and (classname.find('Conv') != -1
+ or classname.find('Linear') != -1):
+ if init_type == 'normal':
+ nn.init.normal_(m.weight.data, 0.0, gain)
+ elif init_type == 'xavier':
+ nn.init.xavier_normal_(m.weight.data, gain=gain)
+ elif init_type == 'xavier_uniform':
+ nn.init.xavier_uniform_(m.weight.data, gain=1.0)
+ elif init_type == 'kaiming':
+ nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+ elif init_type == 'orthogonal':
+ nn.init.orthogonal_(m.weight.data, gain=gain)
+ elif init_type == 'none': # uses pytorch's default init method
+ m.reset_parameters()
+ else:
+ raise NotImplementedError(
+ 'initialization method [%s] is not implemented' %
+ init_type)
+ if hasattr(m, 'bias') and m.bias is not None:
+ nn.init.constant_(m.bias.data, 0.0)
+
+ self.apply(init_func)
+
+ # propagate to children
+ for m in self.children():
+ if hasattr(m, 'init_weights'):
+ m.init_weights(init_type, gain)
+
+
+class Vec2Feat(nn.Module):
+ def __init__(self, channel, hidden, kernel_size, stride, padding):
+ super(Vec2Feat, self).__init__()
+ self.relu = nn.LeakyReLU(0.2, inplace=True)
+ c_out = reduce((lambda x, y: x * y), kernel_size) * channel
+ self.embedding = nn.Linear(hidden, c_out)
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.padding = padding
+ self.bias_conv = nn.Conv2d(channel,
+ channel,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x, t, output_size):
+ b_, _, _, _, c_ = x.shape
+ x = x.view(b_, -1, c_)
+ feat = self.embedding(x)
+ b, _, c = feat.size()
+ feat = feat.view(b * t, -1, c).permute(0, 2, 1)
+ feat = F.fold(feat,
+ output_size=output_size,
+ kernel_size=self.kernel_size,
+ stride=self.stride,
+ padding=self.padding)
+ feat = self.bias_conv(feat)
+ return feat
+
+
+class FusionFeedForward(nn.Module):
+ def __init__(self, dim, hidden_dim=1960, t2t_params=None):
+ super(FusionFeedForward, self).__init__()
+ # We set hidden_dim as a default to 1960
+ self.fc1 = nn.Sequential(nn.Linear(dim, hidden_dim))
+ self.fc2 = nn.Sequential(nn.GELU(), nn.Linear(hidden_dim, dim))
+ assert t2t_params is not None
+ self.t2t_params = t2t_params
+ self.kernel_shape = reduce((lambda x, y: x * y), t2t_params['kernel_size']) # 49
+
+ def forward(self, x, output_size):
+ n_vecs = 1
+ for i, d in enumerate(self.t2t_params['kernel_size']):
+ n_vecs *= int((output_size[i] + 2 * self.t2t_params['padding'][i] -
+ (d - 1) - 1) / self.t2t_params['stride'][i] + 1)
+
+ x = self.fc1(x)
+ b, n, c = x.size()
+ normalizer = x.new_ones(b, n, self.kernel_shape).view(-1, n_vecs, self.kernel_shape).permute(0, 2, 1)
+ normalizer = F.fold(normalizer,
+ output_size=output_size,
+ kernel_size=self.t2t_params['kernel_size'],
+ padding=self.t2t_params['padding'],
+ stride=self.t2t_params['stride'])
+
+ x = F.fold(x.view(-1, n_vecs, c).permute(0, 2, 1),
+ output_size=output_size,
+ kernel_size=self.t2t_params['kernel_size'],
+ padding=self.t2t_params['padding'],
+ stride=self.t2t_params['stride'])
+
+ x = F.unfold(x / normalizer,
+ kernel_size=self.t2t_params['kernel_size'],
+ padding=self.t2t_params['padding'],
+ stride=self.t2t_params['stride']).permute(
+ 0, 2, 1).contiguous().view(b, n, c)
+ x = self.fc2(x)
+ return x
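+
+
+# Minimal sketch (hypothetical helper, not used by the classes above) of the
+# fold/unfold normalisation trick in FusionFeedForward: F.fold sums overlapping
+# patches, so folding a tensor of ones yields, per pixel, the number of patches
+# covering it; dividing the folded features by this count turns the sum into an average.
+def _fold_overlap_counts(output_size=(12, 12), kernel_size=(7, 7), stride=(3, 3), padding=(3, 3)):
+ n_vecs = 1
+ for i, d in enumerate(kernel_size):
+ n_vecs *= (output_size[i] + 2 * padding[i] - (d - 1) - 1) // stride[i] + 1
+ ones = torch.ones(1, kernel_size[0] * kernel_size[1], n_vecs)
+ counts = F.fold(ones, output_size=output_size, kernel_size=kernel_size, stride=stride, padding=padding)
+ return counts # (1, 1, output_size[0], output_size[1]): overlap count at each pixel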
diff --git a/propainter/model/modules/deformconv.py b/propainter/model/modules/deformconv.py
new file mode 100644
index 0000000000000000000000000000000000000000..89cb31b3d80bd69704a380930964db6fb29a6bbe
--- /dev/null
+++ b/propainter/model/modules/deformconv.py
@@ -0,0 +1,54 @@
+import torch
+import torch.nn as nn
+from torch.nn import init as init
+from torch.nn.modules.utils import _pair, _single
+import math
+
+class ModulatedDeformConv2d(nn.Module):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deform_groups=1,
+ bias=True):
+ super(ModulatedDeformConv2d, self).__init__()
+
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.kernel_size = _pair(kernel_size)
+ self.stride = stride
+ self.padding = padding
+ self.dilation = dilation
+ self.groups = groups
+ self.deform_groups = deform_groups
+ self.with_bias = bias
+ # enable compatibility with nn.Conv2d
+ self.transposed = False
+ self.output_padding = _single(0)
+
+ self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size))
+ if bias:
+ self.bias = nn.Parameter(torch.Tensor(out_channels))
+ else:
+ self.register_parameter('bias', None)
+ self.init_weights()
+
+ def init_weights(self):
+ n = self.in_channels
+ for k in self.kernel_size:
+ n *= k
+ stdv = 1. / math.sqrt(n)
+ self.weight.data.uniform_(-stdv, stdv)
+ if self.bias is not None:
+ self.bias.data.zero_()
+
+ if hasattr(self, 'conv_offset'):
+ self.conv_offset.weight.data.zero_()
+ self.conv_offset.bias.data.zero_()
+
+ def forward(self, x, offset, mask):
+ # base-class placeholder: subclasses are expected to implement the actual
+ # modulated deformable convolution
+ pass
\ No newline at end of file
diff --git a/propainter/model/modules/flow_comp_raft.py b/propainter/model/modules/flow_comp_raft.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7f09eba83186726fcc112b7b80172dcd74d5d69
--- /dev/null
+++ b/propainter/model/modules/flow_comp_raft.py
@@ -0,0 +1,270 @@
+import argparse
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+try:
+ from RAFT import RAFT
+ from model.modules.flow_loss_utils import flow_warp, ternary_loss2
+except ImportError:
+ from propainter.RAFT import RAFT
+ from propainter.model.modules.flow_loss_utils import flow_warp, ternary_loss2
+
+
+
+def initialize_RAFT(model_path='weights/raft-things.pth', device='cuda'):
+ """Initializes the RAFT model.
+ """
+ args = argparse.ArgumentParser() # used here only as a plain attribute namespace for the options RAFT expects
+ args.raft_model = model_path
+ args.small = False
+ args.mixed_precision = False
+ args.alternate_corr = False
+ model = torch.nn.DataParallel(RAFT(args))
+ model.load_state_dict(torch.load(args.raft_model, map_location='cpu'))
+ model = model.module
+
+ model.to(device)
+
+ return model
+
+
+class RAFT_bi(nn.Module):
+ """Flow completion loss"""
+ def __init__(self, model_path='weights/raft-things.pth', device='cuda'):
+ super().__init__()
+ self.fix_raft = initialize_RAFT(model_path, device=device)
+
+ for p in self.fix_raft.parameters():
+ p.requires_grad = False
+
+ self.l1_criterion = nn.L1Loss()
+ self.eval()
+
+ def forward(self, gt_local_frames, iters=20):
+ b, l_t, c, h, w = gt_local_frames.size()
+ # print(gt_local_frames.shape)
+
+ with torch.no_grad():
+ gtlf_1 = gt_local_frames[:, :-1, :, :, :].reshape(-1, c, h, w)
+ gtlf_2 = gt_local_frames[:, 1:, :, :, :].reshape(-1, c, h, w)
+ # print(gtlf_1.shape)
+
+ _, gt_flows_forward = self.fix_raft(gtlf_1, gtlf_2, iters=iters, test_mode=True)
+ _, gt_flows_backward = self.fix_raft(gtlf_2, gtlf_1, iters=iters, test_mode=True)
+
+
+ gt_flows_forward = gt_flows_forward.view(b, l_t-1, 2, h, w)
+ gt_flows_backward = gt_flows_backward.view(b, l_t-1, 2, h, w)
+
+ return gt_flows_forward, gt_flows_backward
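+
+
+# Hypothetical usage sketch (the helper name and tensor sizes are illustrative,
+# and a RAFT checkpoint must exist at `model_path`): RAFT_bi returns one
+# forward and one backward flow field per consecutive frame pair.
+def _raft_bi_example(model_path='weights/raft-things.pth', device='cuda'):
+ raft = RAFT_bi(model_path, device=device)
+ frames = torch.rand(1, 5, 3, 240, 432, device=device) # (b, t, c, h, w), sides divisible by 8
+ flows_fw, flows_bw = raft(frames, iters=20)
+ return flows_fw.shape, flows_bw.shape # both (1, 4, 2, 240, 432)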
+
+
+##################################################################################
+def smoothness_loss(flow, cmask):
+ delta_u, delta_v, mask = smoothness_deltas(flow)
+ loss_u = charbonnier_loss(delta_u, cmask)
+ loss_v = charbonnier_loss(delta_v, cmask)
+ return loss_u + loss_v
+
+
+def smoothness_deltas(flow):
+ """
+ flow: [b, c, h, w]
+ """
+ mask_x = create_mask(flow, [[0, 0], [0, 1]])
+ mask_y = create_mask(flow, [[0, 1], [0, 0]])
+ mask = torch.cat((mask_x, mask_y), dim=1)
+ mask = mask.to(flow.device)
+ filter_x = torch.tensor([[0, 0, 0.], [0, 1, -1], [0, 0, 0]])
+ filter_y = torch.tensor([[0, 0, 0.], [0, 1, 0], [0, -1, 0]])
+ weights = torch.ones([2, 1, 3, 3])
+ weights[0, 0] = filter_x
+ weights[1, 0] = filter_y
+ weights = weights.to(flow.device)
+
+ flow_u, flow_v = torch.split(flow, split_size_or_sections=1, dim=1)
+ delta_u = F.conv2d(flow_u, weights, stride=1, padding=1)
+ delta_v = F.conv2d(flow_v, weights, stride=1, padding=1)
+ return delta_u, delta_v, mask
+
+
+def second_order_loss(flow, cmask):
+ delta_u, delta_v, mask = second_order_deltas(flow)
+ loss_u = charbonnier_loss(delta_u, cmask)
+ loss_v = charbonnier_loss(delta_v, cmask)
+ return loss_u + loss_v
+
+
+def charbonnier_loss(x, mask=None, truncate=None, alpha=0.45, beta=1.0, epsilon=0.001):
+ """
+ Compute the generalized charbonnier loss of the difference tensor x
+ All positions where mask == 0 are not taken into account
+ x: a tensor of shape [b, c, h, w]
+ mask: a mask of shape [b, mc, h, w], where mask channels must be either 1 or the same as
+ the number of channels of x. Entries should be 0 or 1
+ return: loss
+ """
+ b, c, h, w = x.shape
+ norm = b * c * h * w
+ error = torch.pow(torch.square(x * beta) + torch.square(torch.tensor(epsilon)), alpha)
+ if mask is not None:
+ error = mask * error
+ if truncate is not None:
+ error = torch.min(error, truncate)
+ return torch.sum(error) / norm
+
+
+def second_order_deltas(flow):
+ """
+ consider the single flow first
+ flow shape: [b, c, h, w]
+ """
+ # create mask
+ mask_x = create_mask(flow, [[0, 0], [1, 1]])
+ mask_y = create_mask(flow, [[1, 1], [0, 0]])
+ mask_diag = create_mask(flow, [[1, 1], [1, 1]])
+ mask = torch.cat((mask_x, mask_y, mask_diag, mask_diag), dim=1)
+ mask = mask.to(flow.device)
+
+ filter_x = torch.tensor([[0, 0, 0.], [1, -2, 1], [0, 0, 0]])
+ filter_y = torch.tensor([[0, 1, 0.], [0, -2, 0], [0, 1, 0]])
+ filter_diag1 = torch.tensor([[1, 0, 0.], [0, -2, 0], [0, 0, 1]])
+ filter_diag2 = torch.tensor([[0, 0, 1.], [0, -2, 0], [1, 0, 0]])
+ weights = torch.ones([4, 1, 3, 3])
+ weights[0] = filter_x
+ weights[1] = filter_y
+ weights[2] = filter_diag1
+ weights[3] = filter_diag2
+ weights = weights.to(flow.device)
+
+ # split the flow into flow_u and flow_v, conv them with the weights
+ flow_u, flow_v = torch.split(flow, split_size_or_sections=1, dim=1)
+ delta_u = F.conv2d(flow_u, weights, stride=1, padding=1)
+ delta_v = F.conv2d(flow_v, weights, stride=1, padding=1)
+ return delta_u, delta_v, mask
+
+def create_mask(tensor, paddings):
+ """
+ tensor shape: [b, c, h, w]
+ paddings: [2 x 2] shape list, the first row indicates up and down paddings
+ the second row indicates left and right paddings
+ | |
+ | x |
+ | x * x |
+ | x |
+ | |
+ """
+ shape = tensor.shape
+ inner_height = shape[2] - (paddings[0][0] + paddings[0][1])
+ inner_width = shape[3] - (paddings[1][0] + paddings[1][1])
+ inner = torch.ones([inner_height, inner_width])
+ torch_paddings = [paddings[1][0], paddings[1][1], paddings[0][0], paddings[0][1]] # left, right, up and down
+ mask2d = F.pad(inner, pad=torch_paddings)
+ mask3d = mask2d.unsqueeze(0).repeat(shape[0], 1, 1)
+ mask4d = mask3d.unsqueeze(1)
+ return mask4d.detach()
+
+def ternary_loss(flow_comp, flow_gt, mask, current_frame, shift_frame, scale_factor=1):
+ if scale_factor != 1:
+ current_frame = F.interpolate(current_frame, scale_factor=1 / scale_factor, mode='bilinear')
+ shift_frame = F.interpolate(shift_frame, scale_factor=1 / scale_factor, mode='bilinear')
+ warped_sc = flow_warp(shift_frame, flow_gt.permute(0, 2, 3, 1))
+ noc_mask = torch.exp(-50. * torch.sum(torch.abs(current_frame - warped_sc), dim=1).pow(2)).unsqueeze(1)
+ warped_comp_sc = flow_warp(shift_frame, flow_comp.permute(0, 2, 3, 1))
+ loss = ternary_loss2(current_frame, warped_comp_sc, noc_mask, mask)
+ return loss
+
+class FlowLoss(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.l1_criterion = nn.L1Loss()
+
+ def forward(self, pred_flows, gt_flows, masks, frames):
+ # pred_flows: b t-1 2 h w
+ loss = 0
+ warp_loss = 0
+ h, w = pred_flows[0].shape[-2:]
+ masks = [masks[:,:-1,...].contiguous(), masks[:, 1:, ...].contiguous()]
+ frames0 = frames[:,:-1,...]
+ frames1 = frames[:,1:,...]
+ current_frames = [frames0, frames1]
+ next_frames = [frames1, frames0]
+ for i in range(len(pred_flows)):
+ # print(pred_flows[i].shape)
+ combined_flow = pred_flows[i] * masks[i] + gt_flows[i] * (1-masks[i])
+ l1_loss = self.l1_criterion(pred_flows[i] * masks[i], gt_flows[i] * masks[i]) / torch.mean(masks[i])
+ l1_loss += self.l1_criterion(pred_flows[i] * (1-masks[i]), gt_flows[i] * (1-masks[i])) / torch.mean((1-masks[i]))
+
+ smooth_loss = smoothness_loss(combined_flow.reshape(-1,2,h,w), masks[i].reshape(-1,1,h,w))
+ smooth_loss2 = second_order_loss(combined_flow.reshape(-1,2,h,w), masks[i].reshape(-1,1,h,w))
+
+ warp_loss_i = ternary_loss(combined_flow.reshape(-1,2,h,w), gt_flows[i].reshape(-1,2,h,w),
+ masks[i].reshape(-1,1,h,w), current_frames[i].reshape(-1,3,h,w), next_frames[i].reshape(-1,3,h,w))
+
+ loss += l1_loss + smooth_loss + smooth_loss2
+
+ warp_loss += warp_loss_i
+
+ return loss, warp_loss
+
+
+def edgeLoss(preds_edges, edges):
+ """
+
+ Args:
+ preds_edges: with shape [b, c, h , w]
+ edges: with shape [b, c, h, w]
+
+ Returns: Edge losses
+
+ """
+ mask = (edges > 0.5).float()
+ b, c, h, w = mask.shape
+ num_pos = torch.sum(mask, dim=[1, 2, 3]).float() # Shape: [b,].
+ num_neg = c * h * w - num_pos # Shape: [b,].
+ neg_weights = (num_neg / (num_pos + num_neg)).unsqueeze(1).unsqueeze(2).unsqueeze(3)
+ pos_weights = (num_pos / (num_pos + num_neg)).unsqueeze(1).unsqueeze(2).unsqueeze(3)
+ weight = neg_weights * mask + pos_weights * (1 - mask) # weight for debug
+ losses = F.binary_cross_entropy_with_logits(preds_edges.float(), edges.float(), weight=weight, reduction='none')
+ loss = torch.mean(losses)
+ return loss
+
+class EdgeLoss(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, pred_edges, gt_edges, masks):
+ # pred_flows: b t-1 1 h w
+ loss = 0
+ h, w = pred_edges[0].shape[-2:]
+ masks = [masks[:,:-1,...].contiguous(), masks[:, 1:, ...].contiguous()]
+ for i in range(len(pred_edges)):
+ # print(f'edges_{i}', torch.sum(gt_edges[i])) # debug
+ combined_edge = pred_edges[i] * masks[i] + gt_edges[i] * (1-masks[i])
+ edge_loss = (edgeLoss(pred_edges[i].reshape(-1,1,h,w), gt_edges[i].reshape(-1,1,h,w)) \
+ + 5 * edgeLoss(combined_edge.reshape(-1,1,h,w), gt_edges[i].reshape(-1,1,h,w)))
+ loss += edge_loss
+
+ return loss
+
+
+class FlowSimpleLoss(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.l1_criterion = nn.L1Loss()
+
+ def forward(self, pred_flows, gt_flows):
+ # pred_flows: b t-1 2 h w
+ loss = 0
+ h, w = pred_flows[0].shape[-2:]
+ h_orig, w_orig = gt_flows[0].shape[-2:]
+ pred_flows = [f.view(-1, 2, h, w) for f in pred_flows]
+ gt_flows = [f.view(-1, 2, h_orig, w_orig) for f in gt_flows]
+
+ ds_factor = 1.0*h/h_orig
+ gt_flows = [F.interpolate(f, scale_factor=ds_factor, mode='area') * ds_factor for f in gt_flows]
+ for i in range(len(pred_flows)):
+ loss += self.l1_criterion(pred_flows[i], gt_flows[i])
+
+ return loss
\ No newline at end of file
diff --git a/propainter/model/modules/flow_loss_utils.py b/propainter/model/modules/flow_loss_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e465c0605df760920b5cfc7f9079fadb74fbec1
--- /dev/null
+++ b/propainter/model/modules/flow_loss_utils.py
@@ -0,0 +1,142 @@
+import torch
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+
+def flow_warp(x,
+ flow,
+ interpolation='bilinear',
+ padding_mode='zeros',
+ align_corners=True):
+ """Warp an image or a feature map with optical flow.
+ Args:
+ x (Tensor): Tensor with size (n, c, h, w).
+ flow (Tensor): Tensor with size (n, h, w, 2). The last dimension holds
+ the horizontal (width) and vertical (height) relative offsets.
+ Note that the values are not normalized to [-1, 1].
+ interpolation (str): Interpolation mode: 'nearest' or 'bilinear'.
+ Default: 'bilinear'.
+ padding_mode (str): Padding mode: 'zeros' or 'border' or 'reflection'.
+ Default: 'zeros'.
+ align_corners (bool): Whether align corners. Default: True.
+ Returns:
+ Tensor: Warped image or feature map.
+ """
+ if x.size()[-2:] != flow.size()[1:3]:
+ raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and '
+ f'flow ({flow.size()[1:3]}) are not the same.')
+ _, _, h, w = x.size()
+ # create mesh grid
+ device = flow.device
+ grid_y, grid_x = torch.meshgrid(torch.arange(0, h, device=device), torch.arange(0, w, device=device))
+ grid = torch.stack((grid_x, grid_y), 2).type_as(x) # (h, w, 2)
+ grid.requires_grad = False
+
+ grid_flow = grid + flow
+ # scale grid_flow to [-1,1]
+ grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0
+ grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0
+ grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3)
+ output = F.grid_sample(x,
+ grid_flow,
+ mode=interpolation,
+ padding_mode=padding_mode,
+ align_corners=align_corners)
+ return output
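+
+
+# Quick sanity sketch (illustrative helper, not part of the upstream API):
+# warping with an all-zero flow should be an identity mapping up to
+# floating-point interpolation error.
+def _flow_warp_identity_check():
+ x = torch.rand(2, 3, 16, 24)
+ zero_flow = torch.zeros(2, 16, 24, 2) # (n, h, w, 2), no displacement
+ warped = flow_warp(x, zero_flow)
+ return torch.allclose(warped, x, atol=1e-5) # expected: True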
+
+
+# def image_warp(image, flow):
+# b, c, h, w = image.size()
+# device = image.device
+# flow = torch.cat([flow[:, 0:1, :, :] / ((w - 1.0) / 2.0), flow[:, 1:2, :, :] / ((h - 1.0) / 2.0)], dim=1) # normalize to [-1~1](from upper left to lower right
+# flow = flow.permute(0, 2, 3, 1) # if you wanna use grid_sample function, the channel(band) shape of show must be in the last dimension
+# x = np.linspace(-1, 1, w)
+# y = np.linspace(-1, 1, h)
+# X, Y = np.meshgrid(x, y)
+# grid = torch.cat((torch.from_numpy(X.astype('float32')).unsqueeze(0).unsqueeze(3),
+# torch.from_numpy(Y.astype('float32')).unsqueeze(0).unsqueeze(3)), 3).to(device)
+# output = torch.nn.functional.grid_sample(image, grid + flow, mode='bilinear', padding_mode='zeros')
+# return output
+
+
+def length_sq(x):
+ return torch.sum(torch.square(x), dim=1, keepdim=True)
+
+
+def fbConsistencyCheck(flow_fw, flow_bw, alpha1=0.01, alpha2=0.5):
+ flow_bw_warped = flow_warp(flow_bw, flow_fw.permute(0, 2, 3, 1)) # wb(wf(x))
+ flow_fw_warped = flow_warp(flow_fw, flow_bw.permute(0, 2, 3, 1)) # wf(wb(x))
+ flow_diff_fw = flow_fw + flow_bw_warped # wf + wb(wf(x))
+ flow_diff_bw = flow_bw + flow_fw_warped # wb + wf(wb(x))
+
+ mag_sq_fw = length_sq(flow_fw) + length_sq(flow_bw_warped) # |wf| + |wb(wf(x))|
+ mag_sq_bw = length_sq(flow_bw) + length_sq(flow_fw_warped) # |wb| + |wf(wb(x))|
+ occ_thresh_fw = alpha1 * mag_sq_fw + alpha2
+ occ_thresh_bw = alpha1 * mag_sq_bw + alpha2
+
+ fb_occ_fw = (length_sq(flow_diff_fw) > occ_thresh_fw).float()
+ fb_occ_bw = (length_sq(flow_diff_bw) > occ_thresh_bw).float()
+
+ return fb_occ_fw, fb_occ_bw # fb_occ_fw -> frame2 area occluded by frame1, fb_occ_bw -> frame1 area occluded by frame2
+
+
+def rgb2gray(image):
+ gray_image = image[:, 0] * 0.299 + image[:, 1] * 0.587 + 0.110 * image[:, 2]
+ gray_image = gray_image.unsqueeze(1)
+ return gray_image
+
+
+def ternary_transform(image, max_distance=1):
+ device = image.device
+ patch_size = 2 * max_distance + 1
+ intensities = rgb2gray(image) * 255
+ out_channels = patch_size * patch_size
+ w = np.eye(out_channels).reshape(out_channels, 1, patch_size, patch_size)
+ weights = torch.from_numpy(w).float().to(device)
+ patches = F.conv2d(intensities, weights, stride=1, padding=max_distance)
+ transf = patches - intensities
+ transf_norm = transf / torch.sqrt(0.81 + torch.square(transf))
+ return transf_norm
+
+
+def hamming_distance(t1, t2):
+ dist = torch.square(t1 - t2)
+ dist_norm = dist / (0.1 + dist)
+ dist_sum = torch.sum(dist_norm, dim=1, keepdim=True)
+ return dist_sum
+
+
+def create_mask(mask, paddings):
+ """
+ padding: [[top, bottom], [left, right]]
+ """
+ shape = mask.shape
+ inner_height = shape[2] - (paddings[0][0] + paddings[0][1])
+ inner_width = shape[3] - (paddings[1][0] + paddings[1][1])
+ inner = torch.ones([inner_height, inner_width])
+
+ mask2d = F.pad(inner, pad=[paddings[1][0], paddings[1][1], paddings[0][0], paddings[0][1]])
+ mask3d = mask2d.unsqueeze(0)
+ mask4d = mask3d.unsqueeze(0).repeat(shape[0], 1, 1, 1)
+ return mask4d.detach()
+
+
+def ternary_loss2(frame1, warp_frame21, confMask, masks, max_distance=1):
+ """
+
+ Args:
+ frame1: torch tensor, with shape [b * t, c, h, w]
+ warp_frame21: torch tensor, with shape [b * t, c, h, w]
+ confMask: confidence mask, with shape [b * t, c, h, w]
+ masks: torch tensor, with shape [b * t, c, h, w]
+ max_distance: maximum distance.
+
+ Returns: ternary loss
+
+ """
+ t1 = ternary_transform(frame1)
+ t21 = ternary_transform(warp_frame21)
+ dist = hamming_distance(t1, t21)
+ loss = torch.mean(dist * confMask * masks) / torch.mean(masks)
+ return loss
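+
+
+# Sanity sketch (illustrative): a frame compared against an identical "warp"
+# has zero census (ternary) distance, so the loss should be ~0.
+def _ternary_loss_identity_check():
+ frame = torch.rand(2, 3, 32, 32)
+ mask = torch.ones(2, 1, 32, 32)
+ return ternary_loss2(frame, frame.clone(), mask, mask) # expected: tensor close to 0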
+
diff --git a/propainter/model/modules/sparse_transformer.py b/propainter/model/modules/sparse_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..11028ffe05a0f59e9d222b0a18f92b1fde12007b
--- /dev/null
+++ b/propainter/model/modules/sparse_transformer.py
@@ -0,0 +1,344 @@
+import math
+from functools import reduce
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+class SoftSplit(nn.Module):
+ def __init__(self, channel, hidden, kernel_size, stride, padding):
+ super(SoftSplit, self).__init__()
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.padding = padding
+ self.t2t = nn.Unfold(kernel_size=kernel_size,
+ stride=stride,
+ padding=padding)
+ c_in = reduce((lambda x, y: x * y), kernel_size) * channel
+ self.embedding = nn.Linear(c_in, hidden)
+
+ def forward(self, x, b, output_size):
+ f_h = int((output_size[0] + 2 * self.padding[0] -
+ (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1)
+ f_w = int((output_size[1] + 2 * self.padding[1] -
+ (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1)
+
+ feat = self.t2t(x)
+ feat = feat.permute(0, 2, 1)
+ # feat shape [b*t, num_vec, ks*ks*c]
+ feat = self.embedding(feat)
+ # feat shape after embedding [b, t*num_vec, hidden]
+ feat = feat.view(b, -1, f_h, f_w, feat.size(2))
+ return feat
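+
+
+# Shape sketch with illustrative values (the helper name is hypothetical):
+# SoftSplit unfolds each frame into patches and embeds them, returning tokens
+# laid out on the patch grid as (b, t, f_h, f_w, hidden).
+def _soft_split_shape_example():
+ ss = SoftSplit(channel=4, hidden=16, kernel_size=(4, 4), stride=(4, 4), padding=(0, 0))
+ x = torch.rand(2 * 3, 4, 12, 12) # (b*t, c, h, w) with b=2, t=3
+ tokens = ss(x, b=2, output_size=(12, 12))
+ return tokens.shape # torch.Size([2, 3, 3, 3, 16])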
+
+
+class SoftComp(nn.Module):
+ def __init__(self, channel, hidden, kernel_size, stride, padding):
+ super(SoftComp, self).__init__()
+ self.relu = nn.LeakyReLU(0.2, inplace=True)
+ c_out = reduce((lambda x, y: x * y), kernel_size) * channel
+ self.embedding = nn.Linear(hidden, c_out)
+ self.kernel_size = kernel_size
+ self.stride = stride
+ self.padding = padding
+ self.bias_conv = nn.Conv2d(channel,
+ channel,
+ kernel_size=3,
+ stride=1,
+ padding=1)
+
+ def forward(self, x, t, output_size):
+ b_, _, _, _, c_ = x.shape
+ x = x.view(b_, -1, c_)
+ feat = self.embedding(x)
+ b, _, c = feat.size()
+ feat = feat.view(b * t, -1, c).permute(0, 2, 1)
+ feat = F.fold(feat,
+ output_size=output_size,
+ kernel_size=self.kernel_size,
+ stride=self.stride,
+ padding=self.padding)
+ feat = self.bias_conv(feat)
+ return feat
+
+
+class FusionFeedForward(nn.Module):
+ def __init__(self, dim, hidden_dim=1960, t2t_params=None):
+ super(FusionFeedForward, self).__init__()
+ # We set hidden_dim as a default to 1960
+ self.fc1 = nn.Sequential(nn.Linear(dim, hidden_dim))
+ self.fc2 = nn.Sequential(nn.GELU(), nn.Linear(hidden_dim, dim))
+ assert t2t_params is not None
+ self.t2t_params = t2t_params
+ self.kernel_shape = reduce((lambda x, y: x * y), t2t_params['kernel_size']) # 49
+
+ def forward(self, x, output_size):
+ n_vecs = 1
+ for i, d in enumerate(self.t2t_params['kernel_size']):
+ n_vecs *= int((output_size[i] + 2 * self.t2t_params['padding'][i] -
+ (d - 1) - 1) / self.t2t_params['stride'][i] + 1)
+
+ x = self.fc1(x)
+ b, n, c = x.size()
+ normalizer = x.new_ones(b, n, self.kernel_shape).view(-1, n_vecs, self.kernel_shape).permute(0, 2, 1)
+ normalizer = F.fold(normalizer,
+ output_size=output_size,
+ kernel_size=self.t2t_params['kernel_size'],
+ padding=self.t2t_params['padding'],
+ stride=self.t2t_params['stride'])
+
+ x = F.fold(x.view(-1, n_vecs, c).permute(0, 2, 1),
+ output_size=output_size,
+ kernel_size=self.t2t_params['kernel_size'],
+ padding=self.t2t_params['padding'],
+ stride=self.t2t_params['stride'])
+
+ x = F.unfold(x / normalizer,
+ kernel_size=self.t2t_params['kernel_size'],
+ padding=self.t2t_params['padding'],
+ stride=self.t2t_params['stride']).permute(
+ 0, 2, 1).contiguous().view(b, n, c)
+ x = self.fc2(x)
+ return x
+
+
+def window_partition(x, window_size, n_head):
+ """
+ Args:
+ x: shape is (B, T, H, W, C)
+ window_size (tuple[int]): window size
+ Returns:
+ windows: (B, num_windows_h, num_windows_w, n_head, T, window_size, window_size, C//n_head)
+ """
+ B, T, H, W, C = x.shape
+ x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1], window_size[1], n_head, C//n_head)
+ windows = x.permute(0, 2, 4, 6, 1, 3, 5, 7).contiguous()
+ return windows
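+
+
+# Shape sketch (illustrative): with window_size=(4, 4) and n_head=2, an input
+# of shape (B, T, H, W, C) = (1, 2, 8, 8, 8) is split into a 2x2 grid of
+# windows, giving (B, n_wh, n_ww, n_head, T, 4, 4, C // n_head).
+def _window_partition_shape_example():
+ x = torch.rand(1, 2, 8, 8, 8)
+ windows = window_partition(x, window_size=(4, 4), n_head=2)
+ return windows.shape # torch.Size([1, 2, 2, 2, 2, 4, 4, 4])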
+
+class SparseWindowAttention(nn.Module):
+ def __init__(self, dim, n_head, window_size, pool_size=(4,4), qkv_bias=True, attn_drop=0., proj_drop=0.,
+ pooling_token=True):
+ super().__init__()
+ assert dim % n_head == 0
+ # key, query, value projections for all heads
+ self.key = nn.Linear(dim, dim, qkv_bias)
+ self.query = nn.Linear(dim, dim, qkv_bias)
+ self.value = nn.Linear(dim, dim, qkv_bias)
+ # regularization
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj_drop = nn.Dropout(proj_drop)
+ # output projection
+ self.proj = nn.Linear(dim, dim)
+ self.n_head = n_head
+ self.window_size = window_size
+ self.pooling_token = pooling_token
+ if self.pooling_token:
+ ks, stride = pool_size, pool_size
+ self.pool_layer = nn.Conv2d(dim, dim, kernel_size=ks, stride=stride, padding=(0, 0), groups=dim)
+ self.pool_layer.weight.data.fill_(1. / (pool_size[0] * pool_size[1]))
+ self.pool_layer.bias.data.fill_(0)
+ # self.expand_size = tuple(i // 2 for i in window_size)
+ self.expand_size = tuple((i + 1) // 2 for i in window_size)
+
+ if any(i > 0 for i in self.expand_size):
+ # get mask for rolled k and rolled v
+ mask_tl = torch.ones(self.window_size[0], self.window_size[1])
+ mask_tl[:-self.expand_size[0], :-self.expand_size[1]] = 0
+ mask_tr = torch.ones(self.window_size[0], self.window_size[1])
+ mask_tr[:-self.expand_size[0], self.expand_size[1]:] = 0
+ mask_bl = torch.ones(self.window_size[0], self.window_size[1])
+ mask_bl[self.expand_size[0]:, :-self.expand_size[1]] = 0
+ mask_br = torch.ones(self.window_size[0], self.window_size[1])
+ mask_br[self.expand_size[0]:, self.expand_size[1]:] = 0
+ mask_rolled = torch.stack((mask_tl, mask_tr, mask_bl, mask_br), 0).flatten(0)
+ self.register_buffer("valid_ind_rolled", mask_rolled.nonzero(as_tuple=False).view(-1))
+
+ self.max_pool = nn.MaxPool2d(window_size, window_size, (0, 0))
+
+
+ def forward(self, x, mask=None, T_ind=None, attn_mask=None):
+ b, t, h, w, c = x.shape # 20 36
+ w_h, w_w = self.window_size[0], self.window_size[1]
+ c_head = c // self.n_head
+ n_wh = math.ceil(h / self.window_size[0])
+ n_ww = math.ceil(w / self.window_size[1])
+ new_h = n_wh * self.window_size[0] # 20
+ new_w = n_ww * self.window_size[1] # 36
+ pad_r = new_w - w
+ pad_b = new_h - h
+ # reverse order
+ if pad_r > 0 or pad_b > 0:
+ x = F.pad(x,(0, 0, 0, pad_r, 0, pad_b, 0, 0), mode='constant', value=0)
+ mask = F.pad(mask,(0, 0, 0, pad_r, 0, pad_b, 0, 0), mode='constant', value=0)
+
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
+ q = self.query(x)
+ k = self.key(x)
+ v = self.value(x)
+ win_q = window_partition(q.contiguous(), self.window_size, self.n_head).view(b, n_wh*n_ww, self.n_head, t, w_h*w_w, c_head)
+ win_k = window_partition(k.contiguous(), self.window_size, self.n_head).view(b, n_wh*n_ww, self.n_head, t, w_h*w_w, c_head)
+ win_v = window_partition(v.contiguous(), self.window_size, self.n_head).view(b, n_wh*n_ww, self.n_head, t, w_h*w_w, c_head)
+ # roll_k and roll_v
+ if any(i > 0 for i in self.expand_size):
+ (k_tl, v_tl) = map(lambda a: torch.roll(a, shifts=(-self.expand_size[0], -self.expand_size[1]), dims=(2, 3)), (k, v))
+ (k_tr, v_tr) = map(lambda a: torch.roll(a, shifts=(-self.expand_size[0], self.expand_size[1]), dims=(2, 3)), (k, v))
+ (k_bl, v_bl) = map(lambda a: torch.roll(a, shifts=(self.expand_size[0], -self.expand_size[1]), dims=(2, 3)), (k, v))
+ (k_br, v_br) = map(lambda a: torch.roll(a, shifts=(self.expand_size[0], self.expand_size[1]), dims=(2, 3)), (k, v))
+
+ (k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows) = map(
+ lambda a: window_partition(a, self.window_size, self.n_head).view(b, n_wh*n_ww, self.n_head, t, w_h*w_w, c_head),
+ (k_tl, k_tr, k_bl, k_br))
+ (v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows) = map(
+ lambda a: window_partition(a, self.window_size, self.n_head).view(b, n_wh*n_ww, self.n_head, t, w_h*w_w, c_head),
+ (v_tl, v_tr, v_bl, v_br))
+ rool_k = torch.cat((k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows), 4).contiguous()
+ rool_v = torch.cat((v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows), 4).contiguous() # [b, n_wh*n_ww, n_head, t, w_h*w_w, c_head]
+ # mask out tokens in current window
+ rool_k = rool_k[:, :, :, :, self.valid_ind_rolled]
+ rool_v = rool_v[:, :, :, :, self.valid_ind_rolled]
+ roll_N = rool_k.shape[4]
+ rool_k = rool_k.view(b, n_wh*n_ww, self.n_head, t, roll_N, c // self.n_head)
+ rool_v = rool_v.view(b, n_wh*n_ww, self.n_head, t, roll_N, c // self.n_head)
+ win_k = torch.cat((win_k, rool_k), dim=4)
+ win_v = torch.cat((win_v, rool_v), dim=4)
+
+ # pool_k and pool_v
+ if self.pooling_token:
+ pool_x = self.pool_layer(x.view(b*t, new_h, new_w, c).permute(0,3,1,2))
+ _, _, p_h, p_w = pool_x.shape
+ pool_x = pool_x.permute(0,2,3,1).view(b, t, p_h, p_w, c)
+ # pool_k
+ pool_k = self.key(pool_x).unsqueeze(1).repeat(1, n_wh*n_ww, 1, 1, 1, 1) # [b, n_wh*n_ww, t, p_h, p_w, c]
+ pool_k = pool_k.view(b, n_wh*n_ww, t, p_h, p_w, self.n_head, c_head).permute(0,1,5,2,3,4,6)
+ pool_k = pool_k.contiguous().view(b, n_wh*n_ww, self.n_head, t, p_h*p_w, c_head)
+ win_k = torch.cat((win_k, pool_k), dim=4)
+ # pool_v
+ pool_v = self.value(pool_x).unsqueeze(1).repeat(1, n_wh*n_ww, 1, 1, 1, 1) # [b, n_wh*n_ww, t, p_h, p_w, c]
+ pool_v = pool_v.view(b, n_wh*n_ww, t, p_h, p_w, self.n_head, c_head).permute(0,1,5,2,3,4,6)
+ pool_v = pool_v.contiguous().view(b, n_wh*n_ww, self.n_head, t, p_h*p_w, c_head)
+ win_v = torch.cat((win_v, pool_v), dim=4)
+
+ # [b, n_wh*n_ww, n_head, t, w_h*w_w, c_head]
+ out = torch.zeros_like(win_q)
+ l_t = mask.size(1)
+
+ mask = self.max_pool(mask.view(b * l_t, new_h, new_w))
+ mask = mask.view(b, l_t, n_wh*n_ww)
+ mask = torch.sum(mask, dim=1) # [b, n_wh*n_ww]
+ for i in range(win_q.shape[0]):
+ ### For masked windows
+ mask_ind_i = mask[i].nonzero(as_tuple=False).view(-1)
+ # mask out query in current window
+ # [b, n_wh*n_ww, n_head, t, w_h*w_w, c_head]
+ mask_n = len(mask_ind_i)
+ if mask_n > 0:
+ win_q_t = win_q[i, mask_ind_i].view(mask_n, self.n_head, t*w_h*w_w, c_head)
+ win_k_t = win_k[i, mask_ind_i]
+ win_v_t = win_v[i, mask_ind_i]
+ # mask out key and value
+ if T_ind is not None:
+ # key [n_wh*n_ww, n_head, t, w_h*w_w, c_head]
+ win_k_t = win_k_t[:, :, T_ind.view(-1)].view(mask_n, self.n_head, -1, c_head)
+ # value
+ win_v_t = win_v_t[:, :, T_ind.view(-1)].view(mask_n, self.n_head, -1, c_head)
+ else:
+ win_k_t = win_k_t.view(n_wh*n_ww, self.n_head, t*w_h*w_w, c_head)
+ win_v_t = win_v_t.view(n_wh*n_ww, self.n_head, t*w_h*w_w, c_head)
+
+ att_t = (win_q_t @ win_k_t.transpose(-2, -1)) * (1.0 / math.sqrt(win_q_t.size(-1)))
+ att_t = F.softmax(att_t, dim=-1)
+ att_t = self.attn_drop(att_t)
+ y_t = att_t @ win_v_t
+
+ out[i, mask_ind_i] = y_t.view(-1, self.n_head, t, w_h*w_w, c_head)
+
+ ### For unmasked windows
+ unmask_ind_i = (mask[i] == 0).nonzero(as_tuple=False).view(-1)
+ # mask out query in current window
+ # [b, n_wh*n_ww, n_head, t, w_h*w_w, c_head]
+ win_q_s = win_q[i, unmask_ind_i]
+ win_k_s = win_k[i, unmask_ind_i, :, :, :w_h*w_w]
+ win_v_s = win_v[i, unmask_ind_i, :, :, :w_h*w_w]
+
+ att_s = (win_q_s @ win_k_s.transpose(-2, -1)) * (1.0 / math.sqrt(win_q_s.size(-1)))
+ att_s = F.softmax(att_s, dim=-1)
+ att_s = self.attn_drop(att_s)
+ y_s = att_s @ win_v_s
+ out[i, unmask_ind_i] = y_s
+
+ # re-assemble all head outputs side by side
+ out = out.view(b, n_wh, n_ww, self.n_head, t, w_h, w_w, c_head)
+ out = out.permute(0, 4, 1, 5, 2, 6, 3, 7).contiguous().view(b, t, new_h, new_w, c)
+
+
+ if pad_r > 0 or pad_b > 0:
+ out = out[:, :, :h, :w, :]
+
+ # output projection
+ out = self.proj_drop(self.proj(out))
+ return out
+
+
+class TemporalSparseTransformer(nn.Module):
+ def __init__(self, dim, n_head, window_size, pool_size,
+ norm_layer=nn.LayerNorm, t2t_params=None):
+ super().__init__()
+ self.window_size = window_size
+ self.attention = SparseWindowAttention(dim, n_head, window_size, pool_size)
+ self.norm1 = norm_layer(dim)
+ self.norm2 = norm_layer(dim)
+ self.mlp = FusionFeedForward(dim, t2t_params=t2t_params)
+
+ def forward(self, x, fold_x_size, mask=None, T_ind=None):
+ """
+ Args:
+ x: image tokens, shape [B T H W C]
+ fold_x_size: fold feature size, shape [60 108]
+ mask: mask tokens, shape [B T H W 1]
+ Returns:
+ out_tokens: shape [B T H W C]
+ """
+ B, T, H, W, C = x.shape # 20 36
+
+ shortcut = x
+ x = self.norm1(x)
+ att_x = self.attention(x, mask, T_ind)
+
+ # FFN
+ x = shortcut + att_x
+ y = self.norm2(x)
+ x = x + self.mlp(y.view(B, T * H * W, C), fold_x_size).view(B, T, H, W, C)
+
+ return x
+
+
+class TemporalSparseTransformerBlock(nn.Module):
+ def __init__(self, dim, n_head, window_size, pool_size, depths, t2t_params=None):
+ super().__init__()
+ blocks = []
+ for i in range(depths):
+ blocks.append(
+ TemporalSparseTransformer(dim, n_head, window_size, pool_size, t2t_params=t2t_params)
+ )
+ self.transformer = nn.Sequential(*blocks)
+ self.depths = depths
+
+ def forward(self, x, fold_x_size, l_mask=None, t_dilation=2):
+ """
+ Args:
+ x: image tokens, shape [B T H W C]
+ fold_x_size: fold feature size, shape [60 108]
+ l_mask: local mask tokens, shape [B T H W 1]
+ Returns:
+ out_tokens: shape [B T H W C]
+ """
+ assert self.depths % t_dilation == 0, 'wrong t_dilation input.'
+ T = x.size(1)
+ T_ind = [torch.arange(i, T, t_dilation) for i in range(t_dilation)] * (self.depths // t_dilation)
+
+ for i in range(0, self.depths):
+ x = self.transformer[i](x, fold_x_size, l_mask, T_ind[i])
+
+ return x
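+
+
+# Illustrative note on `t_dilation` (the helper below is hypothetical): with
+# T=6 frames and t_dilation=2, successive transformer blocks attend to the
+# interleaved frame subsets [0, 2, 4] and [1, 3, 5] in turn.
+def _t_dilation_indices(T=6, t_dilation=2, depths=8):
+ T_ind = [torch.arange(i, T, t_dilation) for i in range(t_dilation)] * (depths // t_dilation)
+ return T_ind[:2] # [tensor([0, 2, 4]), tensor([1, 3, 5])]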
diff --git a/propainter/model/modules/spectral_norm.py b/propainter/model/modules/spectral_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..f38c34e98c03caa28ce0b15a4083215fb7d8e9af
--- /dev/null
+++ b/propainter/model/modules/spectral_norm.py
@@ -0,0 +1,288 @@
+"""
+Spectral Normalization from https://arxiv.org/abs/1802.05957
+"""
+import torch
+from torch.nn.functional import normalize
+
+
+class SpectralNorm(object):
+ # Invariant before and after each forward call:
+ # u = normalize(W @ v)
+ # NB: At initialization, this invariant is not enforced
+
+ _version = 1
+
+ # At version 1:
+ # made `W` not a buffer,
+ # added `v` as a buffer, and
+ # made eval mode use `W = u @ W_orig @ v` rather than the stored `W`.
+
+ def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12):
+ self.name = name
+ self.dim = dim
+ if n_power_iterations <= 0:
+ raise ValueError(
+ 'Expected n_power_iterations to be positive, but '
+ 'got n_power_iterations={}'.format(n_power_iterations))
+ self.n_power_iterations = n_power_iterations
+ self.eps = eps
+
+ def reshape_weight_to_matrix(self, weight):
+ weight_mat = weight
+ if self.dim != 0:
+ # permute dim to front
+ weight_mat = weight_mat.permute(
+ self.dim,
+ *[d for d in range(weight_mat.dim()) if d != self.dim])
+ height = weight_mat.size(0)
+ return weight_mat.reshape(height, -1)
+
+ def compute_weight(self, module, do_power_iteration):
+ # NB: If `do_power_iteration` is set, the `u` and `v` vectors are
+ # updated in power iteration **in-place**. This is very important
+ # because in `DataParallel` forward, the vectors (being buffers) are
+ # broadcast from the parallelized module to each module replica,
+ # which is a new module object created on the fly. And each replica
+ # runs its own spectral norm power iteration. So simply assigning
+ # the updated vectors to the module this function runs on will cause
+ # the update to be lost forever. And the next time the parallelized
+ # module is replicated, the same randomly initialized vectors are
+ # broadcast and used!
+ #
+ # Therefore, to make the change propagate back, we rely on two
+ # important behaviors (also enforced via tests):
+ # 1. `DataParallel` doesn't clone storage if the broadcast tensor
+ # is already on correct device; and it makes sure that the
+ # parallelized module is already on `device[0]`.
+ # 2. If the out tensor in `out=` kwarg has correct shape, it will
+ # just fill in the values.
+ # Therefore, since the same power iteration is performed on all
+ # devices, simply updating the tensors in-place will make sure that
+ # the module replica on `device[0]` will update the _u vector on the
+ # parallelized module (by shared storage).
+ #
+ # However, after we update `u` and `v` in-place, we need to **clone**
+ # them before using them to normalize the weight. This is to support
+ # backproping through two forward passes, e.g., the common pattern in
+ # GAN training: loss = D(real) - D(fake). Otherwise, engine will
+ # complain that variables needed to do backward for the first forward
+ # (i.e., the `u` and `v` vectors) are changed in the second forward.
+ weight = getattr(module, self.name + '_orig')
+ u = getattr(module, self.name + '_u')
+ v = getattr(module, self.name + '_v')
+ weight_mat = self.reshape_weight_to_matrix(weight)
+
+ if do_power_iteration:
+ with torch.no_grad():
+ for _ in range(self.n_power_iterations):
+ # Spectral norm of weight equals to `u^T W v`, where `u` and `v`
+ # are the first left and right singular vectors.
+ # This power iteration produces approximations of `u` and `v`.
+ v = normalize(torch.mv(weight_mat.t(), u),
+ dim=0,
+ eps=self.eps,
+ out=v)
+ u = normalize(torch.mv(weight_mat, v),
+ dim=0,
+ eps=self.eps,
+ out=u)
+ if self.n_power_iterations > 0:
+ # See above on why we need to clone
+ u = u.clone()
+ v = v.clone()
+
+ sigma = torch.dot(u, torch.mv(weight_mat, v))
+ weight = weight / sigma
+ return weight
+
+ def remove(self, module):
+ with torch.no_grad():
+ weight = self.compute_weight(module, do_power_iteration=False)
+ delattr(module, self.name)
+ delattr(module, self.name + '_u')
+ delattr(module, self.name + '_v')
+ delattr(module, self.name + '_orig')
+ module.register_parameter(self.name,
+ torch.nn.Parameter(weight.detach()))
+
+ def __call__(self, module, inputs):
+ setattr(
+ module, self.name,
+ self.compute_weight(module, do_power_iteration=module.training))
+
+ def _solve_v_and_rescale(self, weight_mat, u, target_sigma):
+ # Tries to return a vector `v` s.t. `u = normalize(W @ v)`
+ # (the invariant at top of this class) and `u @ W @ v = sigma`.
+ # This uses pinverse in case W^T W is not invertible.
+ v = torch.chain_matmul(weight_mat.t().mm(weight_mat).pinverse(),
+ weight_mat.t(), u.unsqueeze(1)).squeeze(1)
+ return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v)))
+
+ @staticmethod
+ def apply(module, name, n_power_iterations, dim, eps):
+ for k, hook in module._forward_pre_hooks.items():
+ if isinstance(hook, SpectralNorm) and hook.name == name:
+ raise RuntimeError(
+ "Cannot register two spectral_norm hooks on "
+ "the same parameter {}".format(name))
+
+ fn = SpectralNorm(name, n_power_iterations, dim, eps)
+ weight = module._parameters[name]
+
+ with torch.no_grad():
+ weight_mat = fn.reshape_weight_to_matrix(weight)
+
+ h, w = weight_mat.size()
+ # randomly initialize `u` and `v`
+ u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps)
+ v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps)
+
+ delattr(module, fn.name)
+ module.register_parameter(fn.name + "_orig", weight)
+ # We still need to assign weight back as fn.name because all sorts of
+ # things may assume that it exists, e.g., when initializing weights.
+ # However, we can't directly assign as it could be an nn.Parameter and
+ # gets added as a parameter. Instead, we register weight.data as a plain
+ # attribute.
+ setattr(module, fn.name, weight.data)
+ module.register_buffer(fn.name + "_u", u)
+ module.register_buffer(fn.name + "_v", v)
+
+ module.register_forward_pre_hook(fn)
+
+ module._register_state_dict_hook(SpectralNormStateDictHook(fn))
+ module._register_load_state_dict_pre_hook(
+ SpectralNormLoadStateDictPreHook(fn))
+ return fn
+
+
+# This is a top level class because Py2 pickle doesn't like inner class nor an
+# instancemethod.
+class SpectralNormLoadStateDictPreHook(object):
+ # See docstring of SpectralNorm._version on the changes to spectral_norm.
+ def __init__(self, fn):
+ self.fn = fn
+
+ # For state_dict with version None, (assuming that it has gone through at
+ # least one training forward), we have
+ #
+ # u = normalize(W_orig @ v)
+ # W = W_orig / sigma, where sigma = u @ W_orig @ v
+ #
+ # To compute `v`, we solve `W_orig @ x = u`, and let
+ # v = x / (u @ W_orig @ x) * (W / W_orig).
+ def __call__(self, state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs):
+ fn = self.fn
+ version = local_metadata.get('spectral_norm',
+ {}).get(fn.name + '.version', None)
+ if version is None or version < 1:
+ with torch.no_grad():
+ weight_orig = state_dict[prefix + fn.name + '_orig']
+ # weight = state_dict.pop(prefix + fn.name)
+ # sigma = (weight_orig / weight).mean()
+ weight_mat = fn.reshape_weight_to_matrix(weight_orig)
+ u = state_dict[prefix + fn.name + '_u']
+ # v = fn._solve_v_and_rescale(weight_mat, u, sigma)
+ # state_dict[prefix + fn.name + '_v'] = v
+
+
+# This is a top level class because Py2 pickle doesn't like inner class nor an
+# instancemethod.
+class SpectralNormStateDictHook(object):
+ # See docstring of SpectralNorm._version on the changes to spectral_norm.
+ def __init__(self, fn):
+ self.fn = fn
+
+ def __call__(self, module, state_dict, prefix, local_metadata):
+ if 'spectral_norm' not in local_metadata:
+ local_metadata['spectral_norm'] = {}
+ key = self.fn.name + '.version'
+ if key in local_metadata['spectral_norm']:
+ raise RuntimeError(
+ "Unexpected key in metadata['spectral_norm']: {}".format(key))
+ local_metadata['spectral_norm'][key] = self.fn._version
+
+
+def spectral_norm(module,
+ name='weight',
+ n_power_iterations=1,
+ eps=1e-12,
+ dim=None):
+ r"""Applies spectral normalization to a parameter in the given module.
+
+ .. math::
+ \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})},
+ \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2}
+
+ Spectral normalization stabilizes the training of discriminators (critics)
+ in Generative Adversarial Networks (GANs) by rescaling the weight tensor
+ with spectral norm :math:`\sigma` of the weight matrix calculated using
+ power iteration method. If the dimension of the weight tensor is greater
+ than 2, it is reshaped to 2D in power iteration method to get spectral
+ norm. This is implemented via a hook that calculates spectral norm and
+ rescales weight before every :meth:`~Module.forward` call.
+
+ See `Spectral Normalization for Generative Adversarial Networks`_ .
+
+ .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957
+
+ Args:
+ module (nn.Module): containing module
+ name (str, optional): name of weight parameter
+ n_power_iterations (int, optional): number of power iterations to
+ calculate spectral norm
+ eps (float, optional): epsilon for numerical stability in
+ calculating norms
+ dim (int, optional): dimension corresponding to number of outputs,
+ the default is ``0``, except for modules that are instances of
+ ConvTranspose{1,2,3}d, when it is ``1``
+
+ Returns:
+ The original module with the spectral norm hook
+
+ Example::
+
+ >>> m = spectral_norm(nn.Linear(20, 40))
+ >>> m
+ Linear(in_features=20, out_features=40, bias=True)
+ >>> m.weight_u.size()
+ torch.Size([40])
+
+ """
+ if dim is None:
+ if isinstance(module,
+ (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d,
+ torch.nn.ConvTranspose3d)):
+ dim = 1
+ else:
+ dim = 0
+ SpectralNorm.apply(module, name, n_power_iterations, dim, eps)
+ return module
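+
+# Hedged usage sketch: after wrapping a layer, `weight` is recomputed on every
+# forward pass and its largest singular value is driven towards 1, e.g.
+#   m = spectral_norm(torch.nn.Linear(20, 40))
+#   _ = m(torch.randn(8, 20))                       # runs one power iteration
+#   print(torch.linalg.matrix_norm(m.weight, 2))    # approximately 1.0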
+
+
+def remove_spectral_norm(module, name='weight'):
+ r"""Removes the spectral normalization reparameterization from a module.
+
+ Args:
+ module (Module): containing module
+ name (str, optional): name of weight parameter
+
+ Example:
+ >>> m = spectral_norm(nn.Linear(40, 10))
+ >>> remove_spectral_norm(m)
+ """
+ for k, hook in module._forward_pre_hooks.items():
+ if isinstance(hook, SpectralNorm) and hook.name == name:
+ hook.remove(module)
+ del module._forward_pre_hooks[k]
+ return module
+
+ raise ValueError("spectral_norm of '{}' not found in {}".format(
+ name, module))
+
+
+def use_spectral_norm(module, use_sn=False):
+ if use_sn:
+ return spectral_norm(module)
+ return module
\ No newline at end of file
diff --git a/propainter/model/propainter.py b/propainter/model/propainter.py
new file mode 100644
index 0000000000000000000000000000000000000000..71ec0c551b2eaa114aa1e74cd19bcf1542fb3b04
--- /dev/null
+++ b/propainter/model/propainter.py
@@ -0,0 +1,553 @@
+''' Towards An End-to-End Framework for Video Inpainting
+'''
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision
+
+from einops import rearrange
+
+try:
+ from model.modules.base_module import BaseNetwork
+ from model.modules.sparse_transformer import TemporalSparseTransformerBlock, SoftSplit, SoftComp
+ from model.modules.spectral_norm import spectral_norm as _spectral_norm
+ from model.modules.flow_loss_utils import flow_warp
+ from model.modules.deformconv import ModulatedDeformConv2d
+
+ from .misc import constant_init
+except ImportError:
+ from propainter.model.modules.base_module import BaseNetwork
+ from propainter.model.modules.sparse_transformer import TemporalSparseTransformerBlock, SoftSplit, SoftComp
+ from propainter.model.modules.spectral_norm import spectral_norm as _spectral_norm
+ from propainter.model.modules.flow_loss_utils import flow_warp
+ from propainter.model.modules.deformconv import ModulatedDeformConv2d
+
+ from propainter.model.misc import constant_init
+
+
+def length_sq(x):
+ return torch.sum(torch.square(x), dim=1, keepdim=True)
+
+def fbConsistencyCheck(flow_fw, flow_bw, alpha1=0.01, alpha2=0.5): #debug
+ flow_bw_warped = flow_warp(flow_bw, flow_fw.permute(0, 2, 3, 1)) # wb(wf(x))
+ flow_diff_fw = flow_fw + flow_bw_warped # wf + wb(wf(x))
+
+ mag_sq_fw = length_sq(flow_fw) + length_sq(flow_bw_warped) # |wf| + |wb(wf(x))|
+ occ_thresh_fw = alpha1 * mag_sq_fw + alpha2
+
+ # fb_valid_fw = (length_sq(flow_diff_fw) < occ_thresh_fw).float()
+ fb_valid_fw = (length_sq(flow_diff_fw) < occ_thresh_fw).to(flow_fw)
+ return fb_valid_fw
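+
+# Shape sketch (assumed): `flow_fw` and `flow_bw` are (B, 2, H, W) flow fields.
+# The returned map has shape (B, 1, H, W) and is ~1 where the forward flow and
+# the backward flow warped by it roughly cancel (non-occluded pixels), and ~0
+# where the squared residual exceeds alpha1 * (|wf|^2 + |wb(wf)|^2) + alpha2.
+#   valid = fbConsistencyCheck(torch.zeros(1, 2, 64, 64), torch.zeros(1, 2, 64, 64))  # all ones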
+
+
+class DeformableAlignment(ModulatedDeformConv2d):
+ """Second-order deformable alignment module."""
+ def __init__(self, *args, **kwargs):
+ # self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10)
+ self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 3)
+
+ super(DeformableAlignment, self).__init__(*args, **kwargs)
+
+ self.conv_offset = nn.Sequential(
+ nn.Conv2d(2*self.out_channels + 2 + 1 + 2, self.out_channels, 3, 1, 1),
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
+ nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
+ nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
+ nn.Conv2d(self.out_channels, 27 * self.deform_groups, 3, 1, 1),
+ )
+ self.init_offset()
+
+ def init_offset(self):
+ constant_init(self.conv_offset[-1], val=0, bias=0)
+
+ def forward(self, x, cond_feat, flow):
+ out = self.conv_offset(cond_feat)
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
+
+ # offset
+ offset = self.max_residue_magnitude * torch.tanh(torch.cat((o1, o2), dim=1))
+ offset = offset + flow.flip(1).repeat(1, offset.size(1) // 2, 1, 1)
+
+ # mask
+ mask = torch.sigmoid(mask)
+
+ return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias,
+ self.stride, self.padding,
+ self.dilation, mask)
+
+
+class BidirectionalPropagation(nn.Module):
+ def __init__(self, channel, learnable=True):
+ super(BidirectionalPropagation, self).__init__()
+ self.deform_align = nn.ModuleDict()
+ self.backbone = nn.ModuleDict()
+ self.channel = channel
+ self.prop_list = ['backward_1', 'forward_1']
+ self.learnable = learnable
+
+ if self.learnable:
+ for i, module in enumerate(self.prop_list):
+ self.deform_align[module] = DeformableAlignment(
+ channel, channel, 3, padding=1, deform_groups=16)
+
+ self.backbone[module] = nn.Sequential(
+ nn.Conv2d(2*channel+2, channel, 3, 1, 1),
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(channel, channel, 3, 1, 1),
+ )
+
+ self.fuse = nn.Sequential(
+ nn.Conv2d(2*channel+2, channel, 3, 1, 1),
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(channel, channel, 3, 1, 1),
+ )
+
+ def binary_mask(self, mask, th=0.1):
+ mask[mask>th] = 1
+ mask[mask<=th] = 0
+ # return mask.float()
+ return mask.to(mask)
+
+ def forward(self, x, flows_forward, flows_backward, mask, interpolation='bilinear', direction='forward'):
+ """
+ x shape : [b, t, c, h, w]
+ return [b, t, c, h, w]
+ """
+
+ # For backward warping
+ # pred_flows_forward for backward feature propagation
+ # pred_flows_backward for forward feature propagation
+ b, t, c, h, w = x.shape
+ feats, masks = {}, {}
+ feats['input'] = [x[:, i, :, :, :] for i in range(0, t)]
+ masks['input'] = [mask[:, i, :, :, :] for i in range(0, t)]
+
+ prop_list = ['backward_1', 'forward_1']
+ cache_list = ['input'] + prop_list
+
+ for p_i, module_name in enumerate(prop_list):
+ feats[module_name] = []
+ masks[module_name] = []
+
+ if 'backward' in module_name:
+ frame_idx = range(0, t)
+ frame_idx = frame_idx[::-1]
+ flow_idx = frame_idx
+ flows_for_prop = flows_forward
+ flows_for_check = flows_backward
+ else:
+ frame_idx = range(0, t)
+ flow_idx = range(-1, t - 1)
+ flows_for_prop = flows_backward
+ flows_for_check = flows_forward
+
+ len_frames_idx = len(frame_idx)
+ for i, idx in enumerate(frame_idx):
+ feat_current = feats[cache_list[p_i]][idx]
+ mask_current = masks[cache_list[p_i]][idx]
+
+ if i == 0:
+ feat_prop = feat_current
+ mask_prop = mask_current
+ else:
+ flow_prop = flows_for_prop[:, flow_idx[i], :, :, :]
+ flow_check = flows_for_check[:, flow_idx[i], :, :, :]
+ flow_vaild_mask = fbConsistencyCheck(flow_prop, flow_check)
+ feat_warped = flow_warp(feat_prop, flow_prop.permute(0, 2, 3, 1), interpolation)
+ feat_warped = torch.clamp(feat_warped, min=-1.0, max=1.0)
+
+ if self.learnable:
+ cond = torch.cat([feat_current, feat_warped, flow_prop, flow_vaild_mask, mask_current], dim=1)
+ feat_prop = self.deform_align[module_name](feat_prop, cond, flow_prop)
+ mask_prop = mask_current
+ else:
+ mask_prop_valid = flow_warp(mask_prop, flow_prop.permute(0, 2, 3, 1))
+ mask_prop_valid = self.binary_mask(mask_prop_valid)
+
+ union_vaild_mask = self.binary_mask(mask_current*flow_vaild_mask*(1-mask_prop_valid))
+ feat_prop = union_vaild_mask * feat_warped + (1-union_vaild_mask) * feat_current
+ # update mask
+ mask_prop = self.binary_mask(mask_current*(1-(flow_vaild_mask*(1-mask_prop_valid))))
+
+
+ # refine
+ if self.learnable:
+ feat = torch.cat([feat_current, feat_prop, mask_current], dim=1)
+ feat_prop = feat_prop + self.backbone[module_name](feat)
+ # feat_prop = self.backbone[module_name](feat_prop)
+
+ feats[module_name].append(feat_prop)
+ masks[module_name].append(mask_prop)
+
+
+ # end for
+ if 'backward' in module_name:
+ feats[module_name] = feats[module_name][::-1]
+ masks[module_name] = masks[module_name][::-1]
+
+ outputs_b = torch.stack(feats['backward_1'], dim=1).view(-1, c, h, w)
+ outputs_f = torch.stack(feats['forward_1'], dim=1).view(-1, c, h, w)
+
+ if self.learnable:
+ mask_in = mask.view(-1, 2, h, w)
+ masks_b, masks_f = None, None
+ outputs = self.fuse(torch.cat([outputs_b, outputs_f, mask_in], dim=1)) + x.view(-1, c, h, w)
+ else:
+ if direction == 'forward':
+ masks_b = torch.stack(masks['backward_1'], dim=1)
+ masks_f = torch.stack(masks['forward_1'], dim=1)
+ outputs = outputs_f
+ else:
+ masks_b = torch.stack(masks['backward_1'], dim=1)
+ masks_f = torch.stack(masks['forward_1'], dim=1)
+ outputs = outputs_b
+ return outputs_b.view(b, -1, c, h, w), outputs_f.view(b, -1, c, h, w), \
+ outputs.view(b, -1, c, h, w), masks_b
+
+ return outputs_b.view(b, -1, c, h, w), outputs_f.view(b, -1, c, h, w), \
+ outputs.view(b, -1, c, h, w), masks_f
+
+
+class Encoder(nn.Module):
+ def __init__(self):
+ super(Encoder, self).__init__()
+ self.group = [1, 2, 4, 8, 1]
+ self.layers = nn.ModuleList([
+ nn.Conv2d(5, 64, kernel_size=3, stride=2, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1, groups=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(640, 512, kernel_size=3, stride=1, padding=1, groups=2),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(768, 384, kernel_size=3, stride=1, padding=1, groups=4),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(640, 256, kernel_size=3, stride=1, padding=1, groups=8),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(512, 128, kernel_size=3, stride=1, padding=1, groups=1),
+ nn.LeakyReLU(0.2, inplace=True)
+ ])
+
+ def forward(self, x):
+ bt, c, _, _ = x.size()
+ # h, w = h//4, w//4
+ out = x
+ for i, layer in enumerate(self.layers):
+ if i == 8:
+ x0 = out
+ _, _, h, w = x0.size()
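+            # From layer 10 onward, every conv layer re-injects the feature saved
+            # before layer 8 (256 channels): both tensors are split into `g` groups
+            # and concatenated group-wise, which is why the later convs expect
+            # 640/768/640/512 input channels and use matching `groups` settings.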
+ if i > 8 and i % 2 == 0:
+ g = self.group[(i - 8) // 2]
+ x = x0.view(bt, g, -1, h, w)
+ o = out.view(bt, g, -1, h, w)
+ out = torch.cat([x, o], 2).view(bt, -1, h, w)
+ out = layer(out)
+ return out
+
+
+class deconv(nn.Module):
+ def __init__(self,
+ input_channel,
+ output_channel,
+ kernel_size=3,
+ padding=0):
+ super().__init__()
+ self.conv = nn.Conv2d(input_channel,
+ output_channel,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=padding)
+
+ def forward(self, x):
+ x = F.interpolate(x,
+ scale_factor=2,
+ mode='bilinear',
+ align_corners=True)
+ return self.conv(x)
+
+
+class InpaintGenerator(BaseNetwork):
+ def __init__(self, init_weights=True, model_path=None):
+ super(InpaintGenerator, self).__init__()
+ channel = 128
+ hidden = 512
+
+ # encoder
+ self.encoder = Encoder()
+
+ # decoder
+ self.decoder = nn.Sequential(
+ deconv(channel, 128, kernel_size=3, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ deconv(64, 64, kernel_size=3, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1))
+
+ # soft split and soft composition
+ kernel_size = (7, 7)
+ padding = (3, 3)
+ stride = (3, 3)
+ t2t_params = {
+ 'kernel_size': kernel_size,
+ 'stride': stride,
+ 'padding': padding
+ }
+ self.ss = SoftSplit(channel, hidden, kernel_size, stride, padding)
+ self.sc = SoftComp(channel, hidden, kernel_size, stride, padding)
+ self.max_pool = nn.MaxPool2d(kernel_size, stride, padding)
+
+ # feature propagation module
+ self.img_prop_module = BidirectionalPropagation(3, learnable=False)
+ self.feat_prop_module = BidirectionalPropagation(128, learnable=True)
+
+
+ depths = 8
+ num_heads = 4
+ window_size = (5, 9)
+ pool_size = (4, 4)
+ self.transformers = TemporalSparseTransformerBlock(dim=hidden,
+ n_head=num_heads,
+ window_size=window_size,
+ pool_size=pool_size,
+ depths=depths,
+ t2t_params=t2t_params)
+ if init_weights:
+ self.init_weights()
+
+
+ if model_path is not None:
+ # print('Pretrained ProPainter has loaded...')
+ ckpt = torch.load(model_path, map_location='cpu')
+ self.load_state_dict(ckpt, strict=True)
+
+ # print network parameter number
+ # self.print_network()
+
+ def img_propagation(self, masked_frames, completed_flows, masks, interpolation='nearest', direction = 'forward'):
+ _, _, prop_frames, updated_masks = self.img_prop_module(masked_frames, completed_flows[0], completed_flows[1], masks, interpolation, direction)
+ return prop_frames, updated_masks
+
+ def forward(self, masked_frames, completed_flows, masks_in, masks_updated, num_local_frames, interpolation='bilinear', t_dilation=2):
+ """
+ Args:
+ masks_in: original mask
+ masks_updated: updated mask after image propagation
+ """
+
+ l_t = num_local_frames
+ b, t, _, ori_h, ori_w = masked_frames.size()
+
+ # extracting features
+ enc_feat = self.encoder(torch.cat([masked_frames.view(b * t, 3, ori_h, ori_w),
+ masks_in.view(b * t, 1, ori_h, ori_w),
+ masks_updated.view(b * t, 1, ori_h, ori_w)], dim=1))
+ _, c, h, w = enc_feat.size()
+ local_feat = enc_feat.view(b, t, c, h, w)[:, :l_t, ...]
+ ref_feat = enc_feat.view(b, t, c, h, w)[:, l_t:, ...]
+ fold_feat_size = (h, w)
+
+ ds_flows_f = F.interpolate(completed_flows[0].view(-1, 2, ori_h, ori_w), scale_factor=1/4, mode='bilinear', align_corners=False).view(b, l_t-1, 2, h, w)/4.0
+ ds_flows_b = F.interpolate(completed_flows[1].view(-1, 2, ori_h, ori_w), scale_factor=1/4, mode='bilinear', align_corners=False).view(b, l_t-1, 2, h, w)/4.0
+ ds_mask_in = F.interpolate(masks_in.reshape(-1, 1, ori_h, ori_w), scale_factor=1/4, mode='nearest').view(b, t, 1, h, w)
+ ds_mask_in_local = ds_mask_in[:, :l_t]
+ ds_mask_updated_local = F.interpolate(masks_updated[:,:l_t].reshape(-1, 1, ori_h, ori_w), scale_factor=1/4, mode='nearest').view(b, l_t, 1, h, w)
+
+
+ if self.training:
+ mask_pool_l = self.max_pool(ds_mask_in.view(-1, 1, h, w))
+ mask_pool_l = mask_pool_l.view(b, t, 1, mask_pool_l.size(-2), mask_pool_l.size(-1))
+ else:
+ mask_pool_l = self.max_pool(ds_mask_in_local.view(-1, 1, h, w))
+ mask_pool_l = mask_pool_l.view(b, l_t, 1, mask_pool_l.size(-2), mask_pool_l.size(-1))
+
+
+ prop_mask_in = torch.cat([ds_mask_in_local, ds_mask_updated_local], dim=2)
+ _, _, local_feat, _ = self.feat_prop_module(local_feat, ds_flows_f, ds_flows_b, prop_mask_in, interpolation)
+ enc_feat = torch.cat((local_feat, ref_feat), dim=1)
+
+ trans_feat = self.ss(enc_feat.view(-1, c, h, w), b, fold_feat_size)
+ mask_pool_l = rearrange(mask_pool_l, 'b t c h w -> b t h w c').contiguous()
+ trans_feat = self.transformers(trans_feat, fold_feat_size, mask_pool_l, t_dilation=t_dilation)
+ trans_feat = self.sc(trans_feat, t, fold_feat_size)
+ trans_feat = trans_feat.view(b, t, -1, h, w)
+
+ enc_feat = enc_feat + trans_feat
+
+ if self.training:
+ output = self.decoder(enc_feat.view(-1, c, h, w))
+ output = torch.tanh(output).view(b, t, 3, ori_h, ori_w)
+ else:
+ output = self.decoder(enc_feat[:, :l_t].view(-1, c, h, w))
+ output = torch.tanh(output).view(b, l_t, 3, ori_h, ori_w)
+
+ return output
+
+
+# ######################################################################
+# Discriminator for Temporal Patch GAN
+# ######################################################################
+class Discriminator(BaseNetwork):
+ def __init__(self,
+ in_channels=3,
+ use_sigmoid=False,
+ use_spectral_norm=True,
+ init_weights=True):
+ super(Discriminator, self).__init__()
+ self.use_sigmoid = use_sigmoid
+ nf = 32
+
+ self.conv = nn.Sequential(
+ spectral_norm(
+ nn.Conv3d(in_channels=in_channels,
+ out_channels=nf * 1,
+ kernel_size=(3, 5, 5),
+ stride=(1, 2, 2),
+ padding=1,
+ bias=not use_spectral_norm), use_spectral_norm),
+ # nn.InstanceNorm2d(64, track_running_stats=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ spectral_norm(
+ nn.Conv3d(nf * 1,
+ nf * 2,
+ kernel_size=(3, 5, 5),
+ stride=(1, 2, 2),
+ padding=(1, 2, 2),
+ bias=not use_spectral_norm), use_spectral_norm),
+ # nn.InstanceNorm2d(128, track_running_stats=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ spectral_norm(
+ nn.Conv3d(nf * 2,
+ nf * 4,
+ kernel_size=(3, 5, 5),
+ stride=(1, 2, 2),
+ padding=(1, 2, 2),
+ bias=not use_spectral_norm), use_spectral_norm),
+ # nn.InstanceNorm2d(256, track_running_stats=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ spectral_norm(
+ nn.Conv3d(nf * 4,
+ nf * 4,
+ kernel_size=(3, 5, 5),
+ stride=(1, 2, 2),
+ padding=(1, 2, 2),
+ bias=not use_spectral_norm), use_spectral_norm),
+ # nn.InstanceNorm2d(256, track_running_stats=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ spectral_norm(
+ nn.Conv3d(nf * 4,
+ nf * 4,
+ kernel_size=(3, 5, 5),
+ stride=(1, 2, 2),
+ padding=(1, 2, 2),
+ bias=not use_spectral_norm), use_spectral_norm),
+ # nn.InstanceNorm2d(256, track_running_stats=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv3d(nf * 4,
+ nf * 4,
+ kernel_size=(3, 5, 5),
+ stride=(1, 2, 2),
+ padding=(1, 2, 2)))
+
+ if init_weights:
+ self.init_weights()
+
+ def forward(self, xs):
+ # T, C, H, W = xs.shape (old)
+ # B, T, C, H, W (new)
+ xs_t = torch.transpose(xs, 1, 2)
+ feat = self.conv(xs_t)
+ if self.use_sigmoid:
+ feat = torch.sigmoid(feat)
+ out = torch.transpose(feat, 1, 2) # B, T, C, H, W
+ return out
+
+
+class Discriminator_2D(BaseNetwork):
+ def __init__(self,
+ in_channels=3,
+ use_sigmoid=False,
+ use_spectral_norm=True,
+ init_weights=True):
+ super(Discriminator_2D, self).__init__()
+ self.use_sigmoid = use_sigmoid
+ nf = 32
+
+ self.conv = nn.Sequential(
+ spectral_norm(
+ nn.Conv3d(in_channels=in_channels,
+ out_channels=nf * 1,
+ kernel_size=(1, 5, 5),
+ stride=(1, 2, 2),
+ padding=(0, 2, 2),
+ bias=not use_spectral_norm), use_spectral_norm),
+ # nn.InstanceNorm2d(64, track_running_stats=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ spectral_norm(
+ nn.Conv3d(nf * 1,
+ nf * 2,
+ kernel_size=(1, 5, 5),
+ stride=(1, 2, 2),
+ padding=(0, 2, 2),
+ bias=not use_spectral_norm), use_spectral_norm),
+ # nn.InstanceNorm2d(128, track_running_stats=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ spectral_norm(
+ nn.Conv3d(nf * 2,
+ nf * 4,
+ kernel_size=(1, 5, 5),
+ stride=(1, 2, 2),
+ padding=(0, 2, 2),
+ bias=not use_spectral_norm), use_spectral_norm),
+ # nn.InstanceNorm2d(256, track_running_stats=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ spectral_norm(
+ nn.Conv3d(nf * 4,
+ nf * 4,
+ kernel_size=(1, 5, 5),
+ stride=(1, 2, 2),
+ padding=(0, 2, 2),
+ bias=not use_spectral_norm), use_spectral_norm),
+ # nn.InstanceNorm2d(256, track_running_stats=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ spectral_norm(
+ nn.Conv3d(nf * 4,
+ nf * 4,
+ kernel_size=(1, 5, 5),
+ stride=(1, 2, 2),
+ padding=(0, 2, 2),
+ bias=not use_spectral_norm), use_spectral_norm),
+ # nn.InstanceNorm2d(256, track_running_stats=False),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv3d(nf * 4,
+ nf * 4,
+ kernel_size=(1, 5, 5),
+ stride=(1, 2, 2),
+ padding=(0, 2, 2)))
+
+ if init_weights:
+ self.init_weights()
+
+ def forward(self, xs):
+ # T, C, H, W = xs.shape (old)
+ # B, T, C, H, W (new)
+ xs_t = torch.transpose(xs, 1, 2)
+ feat = self.conv(xs_t)
+ if self.use_sigmoid:
+ feat = torch.sigmoid(feat)
+ out = torch.transpose(feat, 1, 2) # B, T, C, H, W
+ return out
+
+def spectral_norm(module, mode=True):
+ if mode:
+ return _spectral_norm(module)
+ return module
diff --git a/propainter/model/recurrent_flow_completion.py b/propainter/model/recurrent_flow_completion.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1dcc422e42f4e86e94f6eba2ac962a26f4b1084
--- /dev/null
+++ b/propainter/model/recurrent_flow_completion.py
@@ -0,0 +1,352 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision
+
+try:
+ from model.modules.deformconv import ModulatedDeformConv2d
+ from .misc import constant_init
+except ImportError:
+ from propainter.model.modules.deformconv import ModulatedDeformConv2d
+ from propainter.model.misc import constant_init
+
+
+class SecondOrderDeformableAlignment(ModulatedDeformConv2d):
+ """Second-order deformable alignment module."""
+ def __init__(self, *args, **kwargs):
+ self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 5)
+
+ super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs)
+
+ self.conv_offset = nn.Sequential(
+ nn.Conv2d(3 * self.out_channels, self.out_channels, 3, 1, 1),
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
+ nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
+ nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1),
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
+ nn.Conv2d(self.out_channels, 27 * self.deform_groups, 3, 1, 1),
+ )
+ self.init_offset()
+
+ def init_offset(self):
+ constant_init(self.conv_offset[-1], val=0, bias=0)
+
+ def forward(self, x, extra_feat):
+ out = self.conv_offset(extra_feat)
+ o1, o2, mask = torch.chunk(out, 3, dim=1)
+
+ # offset
+ offset = self.max_residue_magnitude * torch.tanh(torch.cat((o1, o2), dim=1))
+ offset_1, offset_2 = torch.chunk(offset, 2, dim=1)
+ offset = torch.cat([offset_1, offset_2], dim=1)
+
+ # mask
+ mask = torch.sigmoid(mask)
+
+ return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias,
+ self.stride, self.padding,
+ self.dilation, mask)
+
+class BidirectionalPropagation(nn.Module):
+ def __init__(self, channel):
+ super(BidirectionalPropagation, self).__init__()
+ modules = ['backward_', 'forward_']
+ self.deform_align = nn.ModuleDict()
+ self.backbone = nn.ModuleDict()
+ self.channel = channel
+
+ for i, module in enumerate(modules):
+ self.deform_align[module] = SecondOrderDeformableAlignment(
+ 2 * channel, channel, 3, padding=1, deform_groups=16)
+
+ self.backbone[module] = nn.Sequential(
+ nn.Conv2d((2 + i) * channel, channel, 3, 1, 1),
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
+ nn.Conv2d(channel, channel, 3, 1, 1),
+ )
+
+ self.fusion = nn.Conv2d(2 * channel, channel, 1, 1, 0)
+
+ def forward(self, x):
+ """
+ x shape : [b, t, c, h, w]
+ return [b, t, c, h, w]
+ """
+ b, t, c, h, w = x.shape
+ feats = {}
+ feats['spatial'] = [x[:, i, :, :, :] for i in range(0, t)]
+
+ for module_name in ['backward_', 'forward_']:
+
+ feats[module_name] = []
+
+ frame_idx = range(0, t)
+ mapping_idx = list(range(0, len(feats['spatial'])))
+ mapping_idx += mapping_idx[::-1]
+
+ if 'backward' in module_name:
+ frame_idx = frame_idx[::-1]
+
+ feat_prop = x.new_zeros(b, self.channel, h, w)
+ for i, idx in enumerate(frame_idx):
+ feat_current = feats['spatial'][mapping_idx[idx]]
+ if i > 0:
+ cond_n1 = feat_prop
+
+ # initialize second-order features
+ feat_n2 = torch.zeros_like(feat_prop)
+ cond_n2 = torch.zeros_like(cond_n1)
+ if i > 1: # second-order features
+ feat_n2 = feats[module_name][-2]
+ cond_n2 = feat_n2
+
+ cond = torch.cat([cond_n1, feat_current, cond_n2], dim=1) # condition information, cond(flow warped 1st/2nd feature)
+ feat_prop = torch.cat([feat_prop, feat_n2], dim=1) # two order feat_prop -1 & -2
+ feat_prop = self.deform_align[module_name](feat_prop, cond)
+
+ # fuse current features
+ feat = [feat_current] + \
+ [feats[k][idx] for k in feats if k not in ['spatial', module_name]] \
+ + [feat_prop]
+
+ feat = torch.cat(feat, dim=1)
+ # embed current features
+ feat_prop = feat_prop + self.backbone[module_name](feat)
+
+ feats[module_name].append(feat_prop)
+
+ # end for
+ if 'backward' in module_name:
+ feats[module_name] = feats[module_name][::-1]
+
+ outputs = []
+ for i in range(0, t):
+ align_feats = [feats[k].pop(0) for k in feats if k != 'spatial']
+ align_feats = torch.cat(align_feats, dim=1)
+ outputs.append(self.fusion(align_feats))
+
+ return torch.stack(outputs, dim=1) + x
+
+
+class deconv(nn.Module):
+ def __init__(self,
+ input_channel,
+ output_channel,
+ kernel_size=3,
+ padding=0):
+ super().__init__()
+ self.conv = nn.Conv2d(input_channel,
+ output_channel,
+ kernel_size=kernel_size,
+ stride=1,
+ padding=padding)
+
+ def forward(self, x):
+ x = F.interpolate(x,
+ scale_factor=2,
+ mode='bilinear',
+ align_corners=True)
+ return self.conv(x)
+
+
+class P3DBlock(nn.Module):
+ def __init__(self, in_channels, out_channels, kernel_size, stride, padding, use_residual=0, bias=True):
+ super().__init__()
+ self.conv1 = nn.Sequential(
+ nn.Conv3d(in_channels, out_channels, kernel_size=(1, kernel_size, kernel_size),
+ stride=(1, stride, stride), padding=(0, padding, padding), bias=bias),
+ nn.LeakyReLU(0.2, inplace=True)
+ )
+ self.conv2 = nn.Sequential(
+ nn.Conv3d(out_channels, out_channels, kernel_size=(3, 1, 1), stride=(1, 1, 1),
+ padding=(2, 0, 0), dilation=(2, 1, 1), bias=bias)
+ )
+ self.use_residual = use_residual
+
+ def forward(self, feats):
+ feat1 = self.conv1(feats)
+ feat2 = self.conv2(feat1)
+ if self.use_residual:
+ output = feats + feat2
+ else:
+ output = feat2
+ return output
+
+
+class EdgeDetection(nn.Module):
+ def __init__(self, in_ch=2, out_ch=1, mid_ch=16):
+ super().__init__()
+ self.projection = nn.Sequential(
+ nn.Conv2d(in_ch, mid_ch, 3, 1, 1),
+ nn.LeakyReLU(0.2, inplace=True)
+ )
+
+ self.mid_layer_1 = nn.Sequential(
+ nn.Conv2d(mid_ch, mid_ch, 3, 1, 1),
+ nn.LeakyReLU(0.2, inplace=True)
+ )
+
+ self.mid_layer_2 = nn.Sequential(
+ nn.Conv2d(mid_ch, mid_ch, 3, 1, 1)
+ )
+
+ self.l_relu = nn.LeakyReLU(0.01, inplace=True)
+
+ self.out_layer = nn.Conv2d(mid_ch, out_ch, 1, 1, 0)
+
+ def forward(self, flow):
+ flow = self.projection(flow)
+ edge = self.mid_layer_1(flow)
+ edge = self.mid_layer_2(edge)
+ edge = self.l_relu(flow + edge)
+ edge = self.out_layer(edge)
+ edge = torch.sigmoid(edge)
+ return edge
+
+
+class RecurrentFlowCompleteNet(nn.Module):
+ def __init__(self, model_path=None):
+ super().__init__()
+ self.downsample = nn.Sequential(
+ nn.Conv3d(3, 32, kernel_size=(1, 5, 5), stride=(1, 2, 2),
+ padding=(0, 2, 2), padding_mode='replicate'),
+ nn.LeakyReLU(0.2, inplace=True)
+ )
+
+ self.encoder1 = nn.Sequential(
+ P3DBlock(32, 32, 3, 1, 1),
+ nn.LeakyReLU(0.2, inplace=True),
+ P3DBlock(32, 64, 3, 2, 1),
+ nn.LeakyReLU(0.2, inplace=True)
+ ) # 4x
+
+ self.encoder2 = nn.Sequential(
+ P3DBlock(64, 64, 3, 1, 1),
+ nn.LeakyReLU(0.2, inplace=True),
+ P3DBlock(64, 128, 3, 2, 1),
+ nn.LeakyReLU(0.2, inplace=True)
+ ) # 8x
+
+ self.mid_dilation = nn.Sequential(
+ nn.Conv3d(128, 128, (1, 3, 3), (1, 1, 1), padding=(0, 3, 3), dilation=(1, 3, 3)), # p = d*(k-1)/2
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv3d(128, 128, (1, 3, 3), (1, 1, 1), padding=(0, 2, 2), dilation=(1, 2, 2)),
+ nn.LeakyReLU(0.2, inplace=True),
+ nn.Conv3d(128, 128, (1, 3, 3), (1, 1, 1), padding=(0, 1, 1), dilation=(1, 1, 1)),
+ nn.LeakyReLU(0.2, inplace=True)
+ )
+
+ # feature propagation module
+ self.feat_prop_module = BidirectionalPropagation(128)
+
+ self.decoder2 = nn.Sequential(
+ nn.Conv2d(128, 128, 3, 1, 1),
+ nn.LeakyReLU(0.2, inplace=True),
+ deconv(128, 64, 3, 1),
+ nn.LeakyReLU(0.2, inplace=True)
+ ) # 4x
+
+ self.decoder1 = nn.Sequential(
+ nn.Conv2d(64, 64, 3, 1, 1),
+ nn.LeakyReLU(0.2, inplace=True),
+ deconv(64, 32, 3, 1),
+ nn.LeakyReLU(0.2, inplace=True)
+ ) # 2x
+
+ self.upsample = nn.Sequential(
+ nn.Conv2d(32, 32, 3, padding=1),
+ nn.LeakyReLU(0.2, inplace=True),
+ deconv(32, 2, 3, 1)
+ )
+
+ # edge loss
+ self.edgeDetector = EdgeDetection(in_ch=2, out_ch=1, mid_ch=16)
+
+        # Need to explicitly initialize the offset weights of the deformable alignment modules
+ for m in self.modules():
+ if isinstance(m, SecondOrderDeformableAlignment):
+ m.init_offset()
+
+ if model_path is not None:
+ # print('Pretrained flow completion model has loaded...')
+ ckpt = torch.load(model_path, map_location='cpu')
+ self.load_state_dict(ckpt, strict=True)
+
+
+ def forward(self, masked_flows, masks):
+ # masked_flows: b t-1 2 h w
+        # masks: b t-1 1 h w
+ b, t, _, h, w = masked_flows.size()
+ masked_flows = masked_flows.permute(0,2,1,3,4)
+ masks = masks.permute(0,2,1,3,4)
+
+ inputs = torch.cat((masked_flows, masks), dim=1)
+
+ x = self.downsample(inputs)
+
+ feat_e1 = self.encoder1(x)
+ feat_e2 = self.encoder2(feat_e1) # b c t h w
+ feat_mid = self.mid_dilation(feat_e2) # b c t h w
+ feat_mid = feat_mid.permute(0,2,1,3,4) # b t c h w
+
+ feat_prop = self.feat_prop_module(feat_mid)
+ feat_prop = feat_prop.view(-1, 128, h//8, w//8) # b*t c h w
+
+ _, c, _, h_f, w_f = feat_e1.shape
+ feat_e1 = feat_e1.permute(0,2,1,3,4).contiguous().view(-1, c, h_f, w_f) # b*t c h w
+ feat_d2 = self.decoder2(feat_prop) + feat_e1
+
+ _, c, _, h_f, w_f = x.shape
+ x = x.permute(0,2,1,3,4).contiguous().view(-1, c, h_f, w_f) # b*t c h w
+
+ feat_d1 = self.decoder1(feat_d2)
+
+ flow = self.upsample(feat_d1)
+ if self.training:
+ edge = self.edgeDetector(flow)
+ edge = edge.view(b, t, 1, h, w)
+ else:
+ edge = None
+
+ flow = flow.view(b, t, 2, h, w)
+
+ return flow, edge
+
+
+ def forward_bidirect_flow(self, masked_flows_bi, masks):
+ """
+ Args:
+ masked_flows_bi: [masked_flows_f, masked_flows_b] | (b t-1 2 h w), (b t-1 2 h w)
+ masks: b t 1 h w
+ """
+ masks_forward = masks[:, :-1, ...].contiguous()
+ masks_backward = masks[:, 1:, ...].contiguous()
+
+ # mask flow
+ masked_flows_forward = masked_flows_bi[0] * (1-masks_forward)
+ masked_flows_backward = masked_flows_bi[1] * (1-masks_backward)
+
+ # -- completion --
+ # forward
+ pred_flows_forward, pred_edges_forward = self.forward(masked_flows_forward, masks_forward)
+
+ # backward
+ masked_flows_backward = torch.flip(masked_flows_backward, dims=[1])
+ masks_backward = torch.flip(masks_backward, dims=[1])
+ pred_flows_backward, pred_edges_backward = self.forward(masked_flows_backward, masks_backward)
+ pred_flows_backward = torch.flip(pred_flows_backward, dims=[1])
+ if self.training:
+ pred_edges_backward = torch.flip(pred_edges_backward, dims=[1])
+
+ return [pred_flows_forward, pred_flows_backward], [pred_edges_forward, pred_edges_backward]
+
+
+ def combine_flow(self, masked_flows_bi, pred_flows_bi, masks):
+ masks_forward = masks[:, :-1, ...].contiguous()
+ masks_backward = masks[:, 1:, ...].contiguous()
+
+ pred_flows_forward = pred_flows_bi[0] * masks_forward + masked_flows_bi[0] * (1-masks_forward)
+ pred_flows_backward = pred_flows_bi[1] * masks_backward + masked_flows_bi[1] * (1-masks_backward)
+
+ return pred_flows_forward, pred_flows_backward
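+
+
+# Hedged usage sketch (assumed shapes): complete bidirectional flows under a mask.
+#   net = RecurrentFlowCompleteNet()
+#   flows_f = torch.zeros(1, 9, 2, 64, 64)    # forward flows,  (b, t-1, 2, h, w)
+#   flows_b = torch.zeros(1, 9, 2, 64, 64)    # backward flows, (b, t-1, 2, h, w)
+#   masks = torch.zeros(1, 10, 1, 64, 64)     # inpainting masks, (b, t, 1, h, w)
+#   pred_bi, _ = net.forward_bidirect_flow([flows_f, flows_b], masks)
+#   comp_f, comp_b = net.combine_flow([flows_f, flows_b], pred_bi, masks)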
diff --git a/propainter/model/vgg_arch.py b/propainter/model/vgg_arch.py
new file mode 100644
index 0000000000000000000000000000000000000000..43fc2ff8bc1c73313d632c6ab326372d389a4772
--- /dev/null
+++ b/propainter/model/vgg_arch.py
@@ -0,0 +1,157 @@
+import os
+import torch
+from collections import OrderedDict
+from torch import nn as nn
+from torchvision.models import vgg as vgg
+
+VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth'
+NAMES = {
+ 'vgg11': [
+ 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
+ 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
+ 'pool5'
+ ],
+ 'vgg13': [
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
+ 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
+ ],
+ 'vgg16': [
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
+ 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
+ 'pool5'
+ ],
+ 'vgg19': [
+ 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
+ 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
+ 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
+ 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
+ ]
+}
+
+
+def insert_bn(names):
+ """Insert bn layer after each conv.
+
+ Args:
+ names (list): The list of layer names.
+
+ Returns:
+ list: The list of layer names with bn layers.
+ """
+ names_bn = []
+ for name in names:
+ names_bn.append(name)
+ if 'conv' in name:
+ position = name.replace('conv', '')
+ names_bn.append('bn' + position)
+ return names_bn
+
+class VGGFeatureExtractor(nn.Module):
+ """VGG network for feature extraction.
+
+    In this implementation, we allow users to choose whether to use normalization
+    on the input feature and which type of VGG network to use. Note that the
+    pretrained path must match the vgg type.
+
+ Args:
+ layer_name_list (list[str]): Forward function returns the corresponding
+ features according to the layer_name_list.
+ Example: {'relu1_1', 'relu2_1', 'relu3_1'}.
+ vgg_type (str): Set the type of vgg network. Default: 'vgg19'.
+        use_input_norm (bool): If True, normalize the input image. Importantly,
+            the input feature must be in the range [0, 1]. Default: True.
+ range_norm (bool): If True, norm images with range [-1, 1] to [0, 1].
+ Default: False.
+ requires_grad (bool): If true, the parameters of VGG network will be
+ optimized. Default: False.
+ remove_pooling (bool): If true, the max pooling operations in VGG net
+ will be removed. Default: False.
+ pooling_stride (int): The stride of max pooling operation. Default: 2.
+ """
+
+ def __init__(self,
+ layer_name_list,
+ vgg_type='vgg19',
+ use_input_norm=True,
+ range_norm=False,
+ requires_grad=False,
+ remove_pooling=False,
+ pooling_stride=2):
+ super(VGGFeatureExtractor, self).__init__()
+
+ self.layer_name_list = layer_name_list
+ self.use_input_norm = use_input_norm
+ self.range_norm = range_norm
+
+ self.names = NAMES[vgg_type.replace('_bn', '')]
+ if 'bn' in vgg_type:
+ self.names = insert_bn(self.names)
+
+ # only borrow layers that will be used to avoid unused params
+ max_idx = 0
+ for v in layer_name_list:
+ idx = self.names.index(v)
+ if idx > max_idx:
+ max_idx = idx
+
+ if os.path.exists(VGG_PRETRAIN_PATH):
+ vgg_net = getattr(vgg, vgg_type)(pretrained=False)
+ state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage)
+ vgg_net.load_state_dict(state_dict)
+ else:
+ vgg_net = getattr(vgg, vgg_type)(pretrained=True)
+
+ features = vgg_net.features[:max_idx + 1]
+
+ modified_net = OrderedDict()
+ for k, v in zip(self.names, features):
+ if 'pool' in k:
+ # if remove_pooling is true, pooling operation will be removed
+ if remove_pooling:
+ continue
+ else:
+ # in some cases, we may want to change the default stride
+ modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride)
+ else:
+ modified_net[k] = v
+
+ self.vgg_net = nn.Sequential(modified_net)
+
+ if not requires_grad:
+ self.vgg_net.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+ else:
+ self.vgg_net.train()
+ for param in self.parameters():
+ param.requires_grad = True
+
+ if self.use_input_norm:
+ # the mean is for image with range [0, 1]
+ self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
+ # the std is for image with range [0, 1]
+ self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))
+
+ def forward(self, x):
+ """Forward function.
+
+ Args:
+ x (Tensor): Input tensor with shape (n, c, h, w).
+
+ Returns:
+ Tensor: Forward results.
+ """
+ if self.range_norm:
+ x = (x + 1) / 2
+ if self.use_input_norm:
+ x = (x - self.mean) / self.std
+ output = {}
+
+ for key, layer in self.vgg_net._modules.items():
+ x = layer(x)
+ if key in self.layer_name_list:
+ output[key] = x.clone()
+
+ return output
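+
+
+# Hedged usage sketch: extract intermediate VGG19 features from images in [0, 1].
+#   extractor = VGGFeatureExtractor(layer_name_list=['relu1_1', 'relu3_1'])
+#   feats = extractor(torch.rand(1, 3, 224, 224))  # dict: layer name -> feature map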
diff --git a/propainter/utils/download_util.py b/propainter/utils/download_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e8fb1b00522309d0c0931f5396355011fb200e7
--- /dev/null
+++ b/propainter/utils/download_util.py
@@ -0,0 +1,109 @@
+import math
+import os
+import requests
+from torch.hub import download_url_to_file, get_dir
+from tqdm import tqdm
+from urllib.parse import urlparse
+
+def sizeof_fmt(size, suffix='B'):
+ """Get human readable file size.
+
+ Args:
+ size (int): File size.
+ suffix (str): Suffix. Default: 'B'.
+
+    Returns:
+        str: Formatted file size.
+ """
+ for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
+ if abs(size) < 1024.0:
+ return f'{size:3.1f} {unit}{suffix}'
+ size /= 1024.0
+ return f'{size:3.1f} Y{suffix}'
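+
+# Example: sizeof_fmt(1536) returns '1.5 KB'; sizeof_fmt(3 * 1024**3) returns '3.0 GB'.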
+
+
+def download_file_from_google_drive(file_id, save_path):
+ """Download files from google drive.
+ Ref:
+ https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501
+ Args:
+ file_id (str): File id.
+ save_path (str): Save path.
+ """
+
+ session = requests.Session()
+ URL = 'https://docs.google.com/uc?export=download'
+ params = {'id': file_id}
+
+ response = session.get(URL, params=params, stream=True)
+ token = get_confirm_token(response)
+ if token:
+ params['confirm'] = token
+ response = session.get(URL, params=params, stream=True)
+
+ # get file size
+ response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
+ print(response_file_size)
+ if 'Content-Range' in response_file_size.headers:
+ file_size = int(response_file_size.headers['Content-Range'].split('/')[1])
+ else:
+ file_size = None
+
+ save_response_content(response, save_path, file_size)
+
+
+def get_confirm_token(response):
+ for key, value in response.cookies.items():
+ if key.startswith('download_warning'):
+ return value
+ return None
+
+
+def save_response_content(response, destination, file_size=None, chunk_size=32768):
+ if file_size is not None:
+ pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
+
+ readable_file_size = sizeof_fmt(file_size)
+ else:
+ pbar = None
+
+ with open(destination, 'wb') as f:
+ downloaded_size = 0
+ for chunk in response.iter_content(chunk_size):
+ downloaded_size += chunk_size
+ if pbar is not None:
+ pbar.update(1)
+ pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}')
+ if chunk: # filter out keep-alive new chunks
+ f.write(chunk)
+ if pbar is not None:
+ pbar.close()
+
+
+def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
+ """Load file form http url, will download models if necessary.
+ Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py
+ Args:
+ url (str): URL to be downloaded.
+ model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir.
+ Default: None.
+ progress (bool): Whether to show the download progress. Default: True.
+ file_name (str): The downloaded file name. If None, use the file name in the url. Default: None.
+ Returns:
+ str: The path to the downloaded file.
+ """
+ if model_dir is None: # use the pytorch hub_dir
+ hub_dir = get_dir()
+ model_dir = os.path.join(hub_dir, 'checkpoints')
+
+ os.makedirs(model_dir, exist_ok=True)
+
+ parts = urlparse(url)
+ filename = os.path.basename(parts.path)
+ if file_name is not None:
+ filename = file_name
+ cached_file = os.path.abspath(os.path.join(model_dir, filename))
+ if not os.path.exists(cached_file):
+ print(f'Downloading: "{url}" to {cached_file}\n')
+ download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
+ return cached_file
\ No newline at end of file
diff --git a/propainter/utils/file_client.py b/propainter/utils/file_client.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7578dec8304475b3906d5dcc734a5a58b56f2ad
--- /dev/null
+++ b/propainter/utils/file_client.py
@@ -0,0 +1,166 @@
+from abc import ABCMeta, abstractmethod
+
+
+class BaseStorageBackend(metaclass=ABCMeta):
+ """Abstract class of storage backends.
+
+ All backends need to implement two apis: ``get()`` and ``get_text()``.
+ ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
+ as texts.
+ """
+
+ @abstractmethod
+ def get(self, filepath):
+ pass
+
+ @abstractmethod
+ def get_text(self, filepath):
+ pass
+
+
+class MemcachedBackend(BaseStorageBackend):
+ """Memcached storage backend.
+
+ Attributes:
+ server_list_cfg (str): Config file for memcached server list.
+ client_cfg (str): Config file for memcached client.
+ sys_path (str | None): Additional path to be appended to `sys.path`.
+ Default: None.
+ """
+
+ def __init__(self, server_list_cfg, client_cfg, sys_path=None):
+ if sys_path is not None:
+ import sys
+ sys.path.append(sys_path)
+ try:
+ import mc
+ except ImportError:
+ raise ImportError('Please install memcached to enable MemcachedBackend.')
+
+ self.server_list_cfg = server_list_cfg
+ self.client_cfg = client_cfg
+ self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg)
+        # mc.pyvector serves as a pointer to a memory cache
+ self._mc_buffer = mc.pyvector()
+
+ def get(self, filepath):
+ filepath = str(filepath)
+ import mc
+ self._client.Get(filepath, self._mc_buffer)
+ value_buf = mc.ConvertBuffer(self._mc_buffer)
+ return value_buf
+
+ def get_text(self, filepath):
+ raise NotImplementedError
+
+
+class HardDiskBackend(BaseStorageBackend):
+ """Raw hard disks storage backend."""
+
+ def get(self, filepath):
+ filepath = str(filepath)
+ with open(filepath, 'rb') as f:
+ value_buf = f.read()
+ return value_buf
+
+ def get_text(self, filepath):
+ filepath = str(filepath)
+ with open(filepath, 'r') as f:
+ value_buf = f.read()
+ return value_buf
+
+
+class LmdbBackend(BaseStorageBackend):
+ """Lmdb storage backend.
+
+ Args:
+ db_paths (str | list[str]): Lmdb database paths.
+ client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
+ readonly (bool, optional): Lmdb environment parameter. If True,
+ disallow any write operations. Default: True.
+ lock (bool, optional): Lmdb environment parameter. If False, when
+ concurrent access occurs, do not lock the database. Default: False.
+ readahead (bool, optional): Lmdb environment parameter. If False,
+ disable the OS filesystem readahead mechanism, which may improve
+ random read performance when a database is larger than RAM.
+ Default: False.
+
+ Attributes:
+ db_paths (list): Lmdb database path.
+ _client (list): A list of several lmdb envs.
+ """
+
+ def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs):
+ try:
+ import lmdb
+ except ImportError:
+ raise ImportError('Please install lmdb to enable LmdbBackend.')
+
+ if isinstance(client_keys, str):
+ client_keys = [client_keys]
+
+ if isinstance(db_paths, list):
+ self.db_paths = [str(v) for v in db_paths]
+ elif isinstance(db_paths, str):
+ self.db_paths = [str(db_paths)]
+ assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
+ f'but received {len(client_keys)} and {len(self.db_paths)}.')
+
+ self._client = {}
+ for client, path in zip(client_keys, self.db_paths):
+ self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs)
+
+ def get(self, filepath, client_key):
+ """Get values according to the filepath from one lmdb named client_key.
+
+ Args:
+ filepath (str | obj:`Path`): Here, filepath is the lmdb key.
+            client_key (str): Used for distinguishing different lmdb envs.
+ """
+ filepath = str(filepath)
+ assert client_key in self._client, (f'client_key {client_key} is not ' 'in lmdb clients.')
+ client = self._client[client_key]
+ with client.begin(write=False) as txn:
+ value_buf = txn.get(filepath.encode('ascii'))
+ return value_buf
+
+ def get_text(self, filepath):
+ raise NotImplementedError
+
+
+class FileClient(object):
+ """A general file client to access files in different backend.
+
+    The client loads a file or text from its path using a specified backend
+    and returns it as a binary buffer. It can also register other backend
+    accessors with a given name and backend class.
+
+ Attributes:
+ backend (str): The storage backend type. Options are "disk",
+ "memcached" and "lmdb".
+ client (:obj:`BaseStorageBackend`): The backend object.
+ """
+
+ _backends = {
+ 'disk': HardDiskBackend,
+ 'memcached': MemcachedBackend,
+ 'lmdb': LmdbBackend,
+ }
+
+ def __init__(self, backend='disk', **kwargs):
+ if backend not in self._backends:
+ raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
+ f' are {list(self._backends.keys())}')
+ self.backend = backend
+ self.client = self._backends[backend](**kwargs)
+
+ def get(self, filepath, client_key='default'):
+ # client_key is used only for lmdb, where different fileclients have
+ # different lmdb environments.
+ if self.backend == 'lmdb':
+ return self.client.get(filepath, client_key)
+ else:
+ return self.client.get(filepath)
+
+ def get_text(self, filepath):
+ return self.client.get_text(filepath)
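+
+
+# Hedged usage sketch: read raw bytes from local disk with the default backend.
+#   client = FileClient(backend='disk')
+#   img_bytes = client.get('path/to/frame.png')        # hypothetical path
+#   label_txt = client.get_text('path/to/labels.txt')  # hypothetical path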
diff --git a/propainter/utils/flow_util.py b/propainter/utils/flow_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..5551e5f173414db8c1088a8fd26b4b0f525fa944
--- /dev/null
+++ b/propainter/utils/flow_util.py
@@ -0,0 +1,196 @@
+import cv2
+import numpy as np
+import os
+import torch.nn.functional as F
+
+def resize_flow(flow, newh, neww):
+ oldh, oldw = flow.shape[0:2]
+ flow = cv2.resize(flow, (neww, newh), interpolation=cv2.INTER_LINEAR)
+ flow[:, :, 0] *= neww / oldw
+ flow[:, :, 1] *= newh / oldh
+ return flow
+
+def resize_flow_pytorch(flow, newh, neww):
+    # Assumes `flow` is a (b, 2, h, w) tensor: channel 0 is dx, channel 1 is dy.
+    oldh, oldw = flow.shape[-2:]
+    flow = F.interpolate(flow, (newh, neww), mode='bilinear')
+    flow[:, 0, :, :] *= neww / oldw
+    flow[:, 1, :, :] *= newh / oldh
+    return flow
+
+
+def imwrite(img, file_path, params=None, auto_mkdir=True):
+ if auto_mkdir:
+ dir_name = os.path.abspath(os.path.dirname(file_path))
+ os.makedirs(dir_name, exist_ok=True)
+ return cv2.imwrite(file_path, img, params)
+
+
+def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs):
+ """Read an optical flow map.
+
+ Args:
+ flow_path (ndarray or str): Flow path.
+ quantize (bool): whether to read quantized pair, if set to True,
+ remaining args will be passed to :func:`dequantize_flow`.
+ concat_axis (int): The axis that dx and dy are concatenated,
+ can be either 0 or 1. Ignored if quantize is False.
+
+ Returns:
+ ndarray: Optical flow represented as a (h, w, 2) numpy array
+ """
+ if quantize:
+ assert concat_axis in [0, 1]
+ cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED)
+ if cat_flow.ndim != 2:
+ raise IOError(f'{flow_path} is not a valid quantized flow file, its dimension is {cat_flow.ndim}.')
+ assert cat_flow.shape[concat_axis] % 2 == 0
+ dx, dy = np.split(cat_flow, 2, axis=concat_axis)
+ flow = dequantize_flow(dx, dy, *args, **kwargs)
+ else:
+ with open(flow_path, 'rb') as f:
+ try:
+ header = f.read(4).decode('utf-8')
+ except Exception:
+ raise IOError(f'Invalid flow file: {flow_path}')
+ else:
+ if header != 'PIEH':
+ raise IOError(f'Invalid flow file: {flow_path}, header does not contain PIEH')
+
+ w = np.fromfile(f, np.int32, 1).squeeze()
+ h = np.fromfile(f, np.int32, 1).squeeze()
+ # flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2))
+ flow = np.fromfile(f, np.float16, w * h * 2).reshape((h, w, 2))
+
+ return flow.astype(np.float32)
+
+
+def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
+ """Write optical flow to file.
+
+    If the flow is not quantized, it will be saved as a .flo file (stored as
+    float16 in this implementation); otherwise it is saved as a jpeg image,
+    which is lossy but much smaller. (dx and dy will be concatenated
+    horizontally into a single image if quantize is True.)
+
+ Args:
+ flow (ndarray): (h, w, 2) array of optical flow.
+ filename (str): Output filepath.
+ quantize (bool): Whether to quantize the flow and save it to 2 jpeg
+ images. If set to True, remaining args will be passed to
+ :func:`quantize_flow`.
+ concat_axis (int): The axis that dx and dy are concatenated,
+ can be either 0 or 1. Ignored if quantize is False.
+ """
+ dir_name = os.path.abspath(os.path.dirname(filename))
+ os.makedirs(dir_name, exist_ok=True)
+ if not quantize:
+ with open(filename, 'wb') as f:
+ f.write('PIEH'.encode('utf-8'))
+ np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
+ # flow = flow.astype(np.float32)
+ flow = flow.astype(np.float16)
+ flow.tofile(f)
+ f.flush()
+ else:
+ assert concat_axis in [0, 1]
+ dx, dy = quantize_flow(flow, *args, **kwargs)
+ dxdy = np.concatenate((dx, dy), axis=concat_axis)
+ # os.makedirs(os.path.dirname(filename), exist_ok=True)
+ cv2.imwrite(filename, dxdy)
+ # imwrite(dxdy, filename)
+
+
+def quantize_flow(flow, max_val=0.02, norm=True):
+ """Quantize flow to [0, 255].
+
+ After this step, the size of flow will be much smaller, and can be
+ dumped as jpeg images.
+
+ Args:
+ flow (ndarray): (h, w, 2) array of optical flow.
+ max_val (float): Maximum value of flow, values beyond
+ [-max_val, max_val] will be truncated.
+ norm (bool): Whether to divide flow values by image width/height.
+
+ Returns:
+ tuple[ndarray]: Quantized dx and dy.
+ """
+ h, w, _ = flow.shape
+ dx = flow[..., 0]
+ dy = flow[..., 1]
+ if norm:
+ dx = dx / w # avoid inplace operations
+ dy = dy / h
+ # use 255 levels instead of 256 to make sure 0 is 0 after dequantization.
+ flow_comps = [quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]]
+ return tuple(flow_comps)
+
+
+def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
+ """Recover from quantized flow.
+
+ Args:
+ dx (ndarray): Quantized dx.
+ dy (ndarray): Quantized dy.
+ max_val (float): Maximum value used when quantizing.
+ denorm (bool): Whether to multiply flow values with width/height.
+
+ Returns:
+ ndarray: Dequantized flow.
+ """
+ assert dx.shape == dy.shape
+ assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)
+
+ dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]]
+
+ if denorm:
+ dx *= dx.shape[1]
+ dy *= dx.shape[0]
+ flow = np.dstack((dx, dy))
+ return flow
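+
+# Hedged round-trip sketch: quantize a small flow field to uint8 and recover it.
+#   flow = np.random.uniform(-0.01, 0.01, (4, 4, 2)).astype(np.float32)
+#   dx, dy = quantize_flow(flow, max_val=0.02, norm=False)
+#   flow_rec = dequantize_flow(dx, dy, max_val=0.02, denorm=False)  # close to `flow`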
+
+
+def quantize(arr, min_val, max_val, levels, dtype=np.int64):
+ """Quantize an array of (-inf, inf) to [0, levels-1].
+
+ Args:
+ arr (ndarray): Input array.
+ min_val (scalar): Minimum value to be clipped.
+ max_val (scalar): Maximum value to be clipped.
+ levels (int): Quantization levels.
+ dtype (np.type): The type of the quantized array.
+
+    Returns:
+        ndarray: Quantized array.
+ """
+ if not (isinstance(levels, int) and levels > 1):
+ raise ValueError(f'levels must be a positive integer, but got {levels}')
+ if min_val >= max_val:
+ raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})')
+
+ arr = np.clip(arr, min_val, max_val) - min_val
+ quantized_arr = np.minimum(np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)
+
+ return quantized_arr
+
+
+def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
+ """Dequantize an array.
+
+ Args:
+ arr (ndarray): Input array.
+ min_val (scalar): Minimum value to be clipped.
+ max_val (scalar): Maximum value to be clipped.
+ levels (int): Quantization levels.
+ dtype (np.type): The type of the dequantized array.
+
+    Returns:
+        ndarray: Dequantized array.
+ """
+ if not (isinstance(levels, int) and levels > 1):
+ raise ValueError(f'levels must be a positive integer, but got {levels}')
+ if min_val >= max_val:
+ raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})')
+
+ dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - min_val) / levels + min_val
+
+ return dequantized_arr
\ No newline at end of file
diff --git a/propainter/utils/img_util.py b/propainter/utils/img_util.py
new file mode 100644
index 0000000000000000000000000000000000000000..d409a132ff216e6943a276fb5d8cd5f410824883
--- /dev/null
+++ b/propainter/utils/img_util.py
@@ -0,0 +1,170 @@
+import cv2
+import math
+import numpy as np
+import os
+import torch
+from torchvision.utils import make_grid
+
+
+def img2tensor(imgs, bgr2rgb=True, float32=True):
+ """Numpy array to tensor.
+
+ Args:
+ imgs (list[ndarray] | ndarray): Input images.
+ bgr2rgb (bool): Whether to change bgr to rgb.
+ float32 (bool): Whether to change to float32.
+
+ Returns:
+ list[tensor] | tensor: Tensor images. If returned results only have
+ one element, just return tensor.
+ """
+
+ def _totensor(img, bgr2rgb, float32):
+ if img.shape[2] == 3 and bgr2rgb:
+ if img.dtype == 'float64':
+ img = img.astype('float32')
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+ img = torch.from_numpy(img.transpose(2, 0, 1))
+ if float32:
+ img = img.float()
+ return img
+
+ if isinstance(imgs, list):
+ return [_totensor(img, bgr2rgb, float32) for img in imgs]
+ else:
+ return _totensor(imgs, bgr2rgb, float32)
+
+
+def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
+ """Convert torch Tensors into image numpy arrays.
+
+ After clamping to [min, max], values will be normalized to [0, 1].
+
+ Args:
+ tensor (Tensor or list[Tensor]): Accept shapes:
+ 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
+ 2) 3D Tensor of shape (3/1 x H x W);
+ 3) 2D Tensor of shape (H x W).
+ Tensor channel should be in RGB order.
+ rgb2bgr (bool): Whether to change rgb to bgr.
+ out_type (numpy type): output types. If ``np.uint8``, transform outputs
+ to uint8 type with range [0, 255]; otherwise, float type with
+ range [0, 1]. Default: ``np.uint8``.
+ min_max (tuple[int]): min and max values for clamp.
+
+ Returns:
+        (ndarray or list[ndarray]): 3D ndarray of shape (H x W x C) OR 2D
+            ndarray of shape (H x W). The channel order is BGR.
+ """
+ if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))):
+ raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}')
+
+ if torch.is_tensor(tensor):
+ tensor = [tensor]
+ result = []
+ for _tensor in tensor:
+ _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
+ _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
+
+ n_dim = _tensor.dim()
+ if n_dim == 4:
+ img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy()
+ img_np = img_np.transpose(1, 2, 0)
+ if rgb2bgr:
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+ elif n_dim == 3:
+ img_np = _tensor.numpy()
+ img_np = img_np.transpose(1, 2, 0)
+ if img_np.shape[2] == 1: # gray image
+ img_np = np.squeeze(img_np, axis=2)
+ else:
+ if rgb2bgr:
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
+ elif n_dim == 2:
+ img_np = _tensor.numpy()
+ else:
+ raise TypeError('Only support 4D, 3D or 2D tensor. ' f'But received with dimension: {n_dim}')
+ if out_type == np.uint8:
+            # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
+ img_np = (img_np * 255.0).round()
+ img_np = img_np.astype(out_type)
+ result.append(img_np)
+ if len(result) == 1:
+ result = result[0]
+ return result
+
+
+def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)):
+    """A slightly faster implementation than tensor2img.
+
+    It only supports torch tensors with shape (1, c, h, w).
+
+    Args:
+        tensor (Tensor): Input tensor with shape (1, c, h, w).
+        rgb2bgr (bool): Whether to change rgb to bgr. Default: True.
+        min_max (tuple[int]): Min and max values for clamp.
+    """
+ output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0)
+ output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255
+ output = output.type(torch.uint8).cpu().numpy()
+ if rgb2bgr:
+ output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR)
+ return output
+
+
+def imfrombytes(content, flag='color', float32=False):
+ """Read an image from bytes.
+
+ Args:
+ content (bytes): Image bytes got from files or other streams.
+ flag (str): Flags specifying the color type of a loaded image,
+ candidates are `color`, `grayscale` and `unchanged`.
+        float32 (bool): Whether to convert to float32. If True, values will
+            also be normalized to [0, 1]. Default: False.
+
+ Returns:
+ ndarray: Loaded image array.
+ """
+ img_np = np.frombuffer(content, np.uint8)
+ imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
+ img = cv2.imdecode(img_np, imread_flags[flag])
+ if float32:
+ img = img.astype(np.float32) / 255.
+ return img
+
+
+def imwrite(img, file_path, params=None, auto_mkdir=True):
+ """Write image to file.
+
+ Args:
+ img (ndarray): Image array to be written.
+ file_path (str): Image file path.
+ params (None or list): Same as opencv's :func:`imwrite` interface.
+ auto_mkdir (bool): If the parent folder of `file_path` does not exist,
+ whether to create it automatically.
+
+ Returns:
+ bool: Successful or not.
+ """
+ if auto_mkdir:
+ dir_name = os.path.abspath(os.path.dirname(file_path))
+ os.makedirs(dir_name, exist_ok=True)
+ return cv2.imwrite(file_path, img, params)
+
+
+def crop_border(imgs, crop_border):
+ """Crop borders of images.
+
+ Args:
+ imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
+        crop_border (int): Crop border for each end of height and width.
+
+ Returns:
+ list[ndarray]: Cropped images.
+ """
+ if crop_border == 0:
+ return imgs
+ else:
+ if isinstance(imgs, list):
+ return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs]
+ else:
+ return imgs[crop_border:-crop_border, crop_border:-crop_border, ...]
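A minimal round-trip sketch of the helpers above, assuming the repository root is on PYTHONPATH; the input frame path is a placeholder.

```python
from propainter.utils.img_util import imfrombytes, img2tensor, tensor2img, imwrite

# 'frame.png' is a placeholder path for this sketch.
with open('frame.png', 'rb') as f:
    img = imfrombytes(f.read(), flag='color', float32=True)   # HWC, BGR, float32 in [0, 1]

tensor = img2tensor(img, bgr2rgb=True, float32=True)           # CHW, RGB, float32
restored = tensor2img(tensor, rgb2bgr=True, min_max=(0, 1))    # HWC, BGR, uint8
imwrite(restored, 'results/roundtrip.png')                     # creates results/ if missing
```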
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..832c939b315597309dcf0324582c296813760fe2
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,19 @@
+torch==2.3.1
+torchvision==0.18.1
+torchaudio==2.3.1
+diffusers==0.29.2
+accelerate==0.25.0
+opencv-python==4.9.0.80
+imageio==2.34.1
+matplotlib
+transformers==4.41.1
+einops==0.8.0
+datasets==2.19.1
+numpy==1.26.4
+pillow==10.4.0
+tqdm==4.66.4
+urllib3==2.2.2
+zipp==3.19.2
+peft==0.13.2
+scipy==1.13.1
+av==14.0.1
\ No newline at end of file
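After installing the pins (for example with `pip install -r requirements.txt`), a quick environment check along these lines may be useful; it is only a sketch and assumes a CUDA-capable GPU is intended.

```python
import torch, torchvision, diffusers, cv2

# Versions should match the pins above (torch 2.3.1, torchvision 0.18.1, diffusers 0.29.2, opencv 4.9.0).
print('torch       :', torch.__version__)
print('torchvision :', torchvision.__version__)
print('diffusers   :', diffusers.__version__)
print('opencv      :', cv2.__version__)
print('CUDA available:', torch.cuda.is_available())
```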
diff --git a/run_diffueraser.py b/run_diffueraser.py
new file mode 100644
index 0000000000000000000000000000000000000000..473c047462f992c156302d8f106f82362bb175fd
--- /dev/null
+++ b/run_diffueraser.py
@@ -0,0 +1,62 @@
+import torch
+import os
+import time
+import argparse
+from diffueraser.diffueraser import DiffuEraser
+from propainter.inference import Propainter, get_device
+
+def main():
+
+ ## input params
+ parser = argparse.ArgumentParser()
+    parser.add_argument('--input_video', type=str, default="examples/example3/video.mp4", help='Path to the input video')
+    parser.add_argument('--input_mask', type=str, default="examples/example3/mask.mp4", help='Path to the input mask')
+    parser.add_argument('--video_length', type=int, default=10, help='The maximum length of the output video')
+    parser.add_argument('--mask_dilation_iter', type=int, default=8, help='Adjust it to change the degree of mask expansion')
+    parser.add_argument('--max_img_size', type=int, default=960, help='The maximum of the output width and height')
+    parser.add_argument('--save_path', type=str, default="results", help='Path to the output')
+    parser.add_argument('--ref_stride', type=int, default=10, help='Stride of global reference frames (ProPainter parameter)')
+    parser.add_argument('--neighbor_length', type=int, default=10, help='Length of local neighboring frames (ProPainter parameter)')
+    parser.add_argument('--subvideo_length', type=int, default=50, help='Length of sub-video clips for long videos (ProPainter parameter)')
+    parser.add_argument('--base_model_path', type=str, default="weights/stable-diffusion-v1-5", help='Path to the SD 1.5 base model')
+    parser.add_argument('--vae_path', type=str, default="weights/sd-vae-ft-mse", help='Path to the VAE')
+    parser.add_argument('--diffueraser_path', type=str, default="weights/diffuEraser", help='Path to the DiffuEraser weights')
+    parser.add_argument('--propainter_model_dir', type=str, default="weights/propainter", help='Path to the priori (ProPainter) model')
+ args = parser.parse_args()
+
+ if not os.path.exists(args.save_path):
+ os.makedirs(args.save_path)
+ priori_path = os.path.join(args.save_path, "priori.mp4")
+ output_path = os.path.join(args.save_path, "diffueraser_result.mp4")
+
+ ## model initialization
+ device = get_device()
+ # PCM params
+ ckpt = "2-Step"
+ video_inpainting_sd = DiffuEraser(device, args.base_model_path, args.vae_path, args.diffueraser_path, ckpt=ckpt)
+ propainter = Propainter(args.propainter_model_dir, device=device)
+
+ start_time = time.time()
+
+ ## priori
+    propainter.forward(args.input_video, args.input_mask, priori_path, video_length=args.video_length,
+                       ref_stride=args.ref_stride, neighbor_length=args.neighbor_length,
+                       subvideo_length=args.subvideo_length, mask_dilation=args.mask_dilation_iter)
+
+    ## diffueraser
+    guidance_scale = None  # None falls back to the default guidance scale of 0.
+    video_inpainting_sd.forward(args.input_video, args.input_mask, priori_path, output_path,
+                                max_img_size=args.max_img_size, video_length=args.video_length,
+                                mask_dilation_iter=args.mask_dilation_iter, guidance_scale=guidance_scale)
+
+ end_time = time.time()
+ inference_time = end_time - start_time
+ print(f"DiffuEraser inference time: {inference_time:.4f} s")
+
+ torch.cuda.empty_cache()
+
+if __name__ == '__main__':
+ main()
+
+
+
\ No newline at end of file
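The script runs with its defaults via `python run_diffueraser.py`, or with explicit flags. As a hedged sketch, the same two-stage pattern (ProPainter priori, then DiffuEraser refinement) can be looped over several clips; the `examples/<name>/video.mp4` and `mask.mp4` layout assumed below is illustrative.

```python
import os
from diffueraser.diffueraser import DiffuEraser
from propainter.inference import Propainter, get_device

device = get_device()
video_inpainting_sd = DiffuEraser(device, "weights/stable-diffusion-v1-5",
                                  "weights/sd-vae-ft-mse", "weights/diffuEraser", ckpt="2-Step")
propainter = Propainter("weights/propainter", device=device)

for name in ("example2", "example3"):                    # assumed example folders
    out_dir = os.path.join("results", name)
    os.makedirs(out_dir, exist_ok=True)
    video = f"examples/{name}/video.mp4"
    mask = f"examples/{name}/mask.mp4"
    priori = os.path.join(out_dir, "priori.mp4")
    result = os.path.join(out_dir, "diffueraser_result.mp4")

    # Stage 1: ProPainter produces the priori video.
    propainter.forward(video, mask, priori, video_length=10, ref_stride=10,
                       neighbor_length=10, subvideo_length=50, mask_dilation=8)
    # Stage 2: DiffuEraser refines the priori.
    video_inpainting_sd.forward(video, mask, priori, result, max_img_size=960,
                                video_length=10, mask_dilation_iter=8, guidance_scale=None)
```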
diff --git a/weights/README.md b/weights/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..d6b8403cd6916408580cbc55bc4c83a586c6a2bc
--- /dev/null
+++ b/weights/README.md
@@ -0,0 +1,22 @@
+Put the downloaded pre-trained models in this folder.
+
+The directory structure should be arranged as follows:
+```
+weights
+ |- diffuEraser
+ |-brushnet
+ |-unet_main
+ |- stable-diffusion-v1-5
+ |-feature_extractor
+ |-...
+ |- PCM_Weights
+ |-sd15
+ |- propainter
+ |-ProPainter.pth
+ |-raft-things.pth
+ |-recurrent_flow_completion.pth
+ |- sd-vae-ft-mse
+ |-diffusion_pytorch_model.bin
+ |-...
+ |- README.md
+```
\ No newline at end of file