diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..cb5a34493dd6a0dbed9c02f2f6bc88e3e1ec3e3a 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +examples/example2/mask.mp4 filter=lfs diff=lfs merge=lfs -text +examples/example3/video.mp4 filter=lfs diff=lfs merge=lfs -text diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..7a4a3ea2424c09fbe48d455aed1eaa94d9124835 --- /dev/null +++ b/LICENSE @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/ORIGINAL_README.md b/ORIGINAL_README.md new file mode 100644 index 0000000000000000000000000000000000000000..bc7e5c167422ac28bf4702f7121a6b03528f415b --- /dev/null +++ b/ORIGINAL_README.md @@ -0,0 +1,171 @@ +
+
+# DiffuEraser: A Diffusion Model for Video Inpainting
+
+Xiaowen Li, Haolan Xue, Peiran Ren, Liefeng Bo
+
+Tongyi Lab, Alibaba Group
+
+TECHNICAL REPORT
+
+
+DiffuEraser is a diffusion model for video inpainting that outperforms the state-of-the-art model ProPainter in both content completeness and temporal consistency while maintaining acceptable efficiency.
+
+---
+
+
+## Update Log
+- *2025.01.20*: Release inference code.
+
+
+## TODO
+- [ ] Release training code.
+- [ ] Release HuggingFace/ModelScope demo.
+- [ ] Release gradio demo.
+
+
+## Results
+More results will be displayed on the project page.
+
+https://github.com/user-attachments/assets/b59d0b88-4186-4531-8698-adf6e62058f8
+
+
+## Method Overview
+Our network is inspired by [BrushNet](https://github.com/TencentARC/BrushNet) and [AnimateDiff](https://github.com/guoyww/AnimateDiff). The architecture comprises the primary `denoising UNet` and an auxiliary `BrushNet branch`. Features extracted by the BrushNet branch are integrated into the denoising UNet layer by layer after a zero-convolution block. The denoising UNet performs the denoising process to generate the final output. To enhance temporal consistency, `temporal attention` mechanisms are incorporated after both the self-attention and cross-attention layers. After denoising, the generated images are blended with the input masked images using blurred masks.
+
+![overall_structure](assets/DiffuEraser_pipeline.png)
+
+We incorporate `prior` information to provide initialization and weak conditioning, which helps mitigate noisy artifacts and suppress hallucinations.
+Additionally, to improve temporal consistency during long-sequence inference, we expand the `temporal receptive fields` of both the prior model and DiffuEraser, and further enhance consistency by leveraging the temporal smoothing capabilities of Video Diffusion Models. Please read the paper for details.
+
+
+## Getting Started
+
+#### Installation
+
+1. Clone Repo
+
+   ```bash
+   git clone https://github.com/lixiaowen-xw/DiffuEraser.git
+   ```
+
+2. Create Conda Environment and Install Dependencies
+
+   ```bash
+   # create new anaconda env
+   conda create -n diffueraser python=3.9.19
+   conda activate diffueraser
+   # install python dependencies
+   pip install -r requirements.txt
+   ```
+
+#### Prepare pretrained models
+Weights will be placed under the `./weights` directory.
+1. Download our pretrained models from [Hugging Face](https://huggingface.co/lixiaowen/diffuEraser) or [ModelScope](https://www.modelscope.cn/xingzi/diffuEraser.git) to the `weights` folder.
+2. Download the pretrained weights of the base model and other components:
+   - [stable-diffusion-v1-5](https://huggingface.co/stable-diffusion-v1-5/stable-diffusion-v1-5). The full folder is over 30 GB. To save storage space, you can download only the necessary files: feature_extractor, model_index.json, safety_checker, scheduler, text_encoder, and tokenizer (about 4 GB).
+   - [PCM_Weights](https://huggingface.co/wangfuyun/PCM_Weights)
+   - [propainter](https://github.com/sczhou/ProPainter/releases/tag/v0.1.0)
+   - [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse)
+
+
+The directory structure will be arranged as:
+```
+weights
+   |- diffuEraser
+      |- brushnet
+      |- unet_main
+   |- stable-diffusion-v1-5
+      |- feature_extractor
+      |- ...
+   |- PCM_Weights
+      |- sd15
+   |- propainter
+      |- ProPainter.pth
+      |- raft-things.pth
+      |- recurrent_flow_completion.pth
+   |- sd-vae-ft-mse
+      |- diffusion_pytorch_model.bin
+      |- ...
+   |- README.md
+```
+
+#### Main Inference
+We provide some examples in the [`examples`](./examples) folder.
+Run the following commands to try it out:
+```shell
+cd DiffuEraser
+python run_diffueraser.py
+```
+The results will be saved in the `results` folder.
+To test your own videos, replace `input_video` and `input_mask` in run_diffueraser.py. The first inference may take a long time. For programmatic use of the `DiffuEraser` class, see the sketch at the end of this README.
+
+The `frame rate` of input_video and input_mask must be consistent. We currently only support `mp4 video` as input instead of split frames; you can convert frames to a video using ffmpeg:
+```shell
+ffmpeg -i image%03d.jpg -c:v libx264 -r 25 output.mp4
+```
+Note: Do not convert the frame rate of the mask video if it is not consistent with that of the input video; the resulting misalignment would lead to errors.
+
+
+The table below shows the estimated GPU memory requirements and inference time for different resolutions:
+
+| Resolution | GPU Memory | Inference Time (250 frames, ~10 s, L20 GPU) |
+| :--------- | :--------: | :-----------------------------------------: |
+| 1280 x 720 | 33 GB | 314 s |
+| 960 x 540 | 20 GB | 175 s |
+| 640 x 360 | 12 GB | 92 s |
+
+
+## Citation
+
+If you find our repo useful for your research, please consider citing our paper:
+
+```bibtex
+@misc{li2025diffueraserdiffusionmodelvideo,
+      title={DiffuEraser: A Diffusion Model for Video Inpainting},
+      author={Xiaowen Li and Haolan Xue and Peiran Ren and Liefeng Bo},
+      year={2025},
+      eprint={2501.10018},
+      archivePrefix={arXiv},
+      primaryClass={cs.CV},
+      url={https://arxiv.org/abs/2501.10018},
+}
+```
+
+
+## License
+This repository uses [ProPainter](https://github.com/sczhou/ProPainter) as the prior model. Users must comply with [ProPainter's license](https://github.com/sczhou/ProPainter/blob/main/LICENSE) when using this code, or you can replace it with another model.
+
+This project is licensed under the [Apache License Version 2.0](./LICENSE) except for the third-party components listed below.
+
+
+## Acknowledgement
+
+This code is based on [BrushNet](https://github.com/TencentARC/BrushNet), [ProPainter](https://github.com/sczhou/ProPainter) and [AnimateDiff](https://github.com/guoyww/AnimateDiff). The example videos come from [Pexels](https://www.pexels.com/), [DAVIS](https://davischallenge.org/), [SA-V](https://ai.meta.com/datasets/segment-anything-video) and [DanceTrack](https://dancetrack.github.io/). Thanks for their awesome work.
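+
+Below is a minimal sketch of driving the `DiffuEraser` class from `diffueraser/diffueraser.py` directly instead of through `run_diffueraser.py`. The example file paths and the pre-computed `priori` video are assumptions; the `priori` must be a video already inpainted by the prior model (ProPainter) with the same frame rate as the input and mask.
+
+```python
+# Minimal sketch (assumptions: example paths, a priori video already produced by ProPainter,
+# and the weights/ layout shown above). Run from the repository root so the relative
+# "weights/PCM_Weights" path used inside DiffuEraser resolves.
+from diffueraser.diffueraser import DiffuEraser
+
+device = "cuda"
+
+model = DiffuEraser(
+    device,
+    base_model_path="weights/stable-diffusion-v1-5",
+    vae_path="weights/sd-vae-ft-mse",
+    diffueraser_path="weights/diffuEraser",
+    ckpt="2-Step",   # any key of the `checkpoints` dict in diffueraser/diffueraser.py
+    mode="sd15",
+)
+
+# `priori` is consumed (and deleted) by read_priori(), so keep a copy if you need it again.
+output = model.forward(
+    validation_image="examples/example3/video.mp4",  # assumed example path
+    validation_mask="examples/example3/mask.mp4",    # assumed example path
+    priori="results/priori.mp4",                     # assumed ProPainter output
+    output_path="results/diffueraser_result.mp4",
+    max_img_size=960,
+    video_length=2,   # seconds of video to process
+    nframes=22,
+    seed=42,
+)
+print("Inpainted video saved to", output)
+```
+
+Setting `blended=True` (the default) feathers the mask with a Gaussian blur before compositing the generated frames over the original ones, which corresponds to the blurred-mask blending described in the Method Overview.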
+ + diff --git a/assets/DiffuEraser_pipeline.png b/assets/DiffuEraser_pipeline.png new file mode 100644 index 0000000000000000000000000000000000000000..ab87700db6ffc786aee5f205cb18c1eed158bb4f Binary files /dev/null and b/assets/DiffuEraser_pipeline.png differ diff --git a/diffueraser/diffueraser.py b/diffueraser/diffueraser.py new file mode 100644 index 0000000000000000000000000000000000000000..108e72a0f360e375aaa1e0fe74ca71d3a90260cc --- /dev/null +++ b/diffueraser/diffueraser.py @@ -0,0 +1,432 @@ +import gc +import copy +import cv2 +import os +import numpy as np +import torch +import torchvision +from einops import repeat +from PIL import Image, ImageFilter +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + UniPCMultistepScheduler, + LCMScheduler, +) +from diffusers.schedulers import TCDScheduler +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.utils.torch_utils import randn_tensor +from transformers import AutoTokenizer, PretrainedConfig + +from libs.unet_motion_model import MotionAdapter, UNetMotionModel +from libs.brushnet_CA import BrushNetModel +from libs.unet_2d_condition import UNet2DConditionModel +from diffueraser.pipeline_diffueraser import StableDiffusionDiffuEraserPipeline + + +checkpoints = { + "2-Step": ["pcm_{}_smallcfg_2step_converted.safetensors", 2, 0.0], + "4-Step": ["pcm_{}_smallcfg_4step_converted.safetensors", 4, 0.0], + "8-Step": ["pcm_{}_smallcfg_8step_converted.safetensors", 8, 0.0], + "16-Step": ["pcm_{}_smallcfg_16step_converted.safetensors", 16, 0.0], + "Normal CFG 4-Step": ["pcm_{}_normalcfg_4step_converted.safetensors", 4, 7.5], + "Normal CFG 8-Step": ["pcm_{}_normalcfg_8step_converted.safetensors", 8, 7.5], + "Normal CFG 16-Step": ["pcm_{}_normalcfg_16step_converted.safetensors", 16, 7.5], + "LCM-Like LoRA": [ + "pcm_{}_lcmlike_lora_converted.safetensors", + 4, + 0.0, + ], +} + +def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str): + text_encoder_config = PretrainedConfig.from_pretrained( + pretrained_model_name_or_path, + subfolder="text_encoder", + revision=revision, + ) + model_class = text_encoder_config.architectures[0] + + if model_class == "CLIPTextModel": + from transformers import CLIPTextModel + + return CLIPTextModel + elif model_class == "RobertaSeriesModelWithTransformation": + from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation + + return RobertaSeriesModelWithTransformation + else: + raise ValueError(f"{model_class} is not supported.") + +def resize_frames(frames, size=None): + if size is not None: + out_size = size + process_size = (out_size[0]-out_size[0]%8, out_size[1]-out_size[1]%8) + frames = [f.resize(process_size) for f in frames] + else: + out_size = frames[0].size + process_size = (out_size[0]-out_size[0]%8, out_size[1]-out_size[1]%8) + if not out_size == process_size: + frames = [f.resize(process_size) for f in frames] + + return frames + +def read_mask(validation_mask, fps, n_total_frames, img_size, mask_dilation_iter, frames): + cap = cv2.VideoCapture(validation_mask) + if not cap.isOpened(): + print("Error: Could not open mask video.") + exit() + mask_fps = cap.get(cv2.CAP_PROP_FPS) + if mask_fps != fps: + cap.release() + raise ValueError("The frame rate of all input videos needs to be consistent.") + + masks = [] + masked_images = [] + idx = 0 + while True: + ret, frame = cap.read() + if not ret: + break + if(idx >= n_total_frames): + break + mask = 
Image.fromarray(frame[...,::-1]).convert('L') + if mask.size != img_size: + mask = mask.resize(img_size, Image.NEAREST) + mask = np.asarray(mask) + m = np.array(mask > 0).astype(np.uint8) + m = cv2.erode(m, + cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3)), + iterations=1) + m = cv2.dilate(m, + cv2.getStructuringElement(cv2.MORPH_RECT, (3, 3)), + iterations=mask_dilation_iter) + + mask = Image.fromarray(m * 255) + masks.append(mask) + + masked_image = np.array(frames[idx])*(1-(np.array(mask)[:,:,np.newaxis].astype(np.float32)/255)) + masked_image = Image.fromarray(masked_image.astype(np.uint8)) + masked_images.append(masked_image) + + idx += 1 + cap.release() + + return masks, masked_images + +def read_priori(priori, fps, n_total_frames, img_size): + cap = cv2.VideoCapture(priori) + if not cap.isOpened(): + print("Error: Could not open video.") + exit() + priori_fps = cap.get(cv2.CAP_PROP_FPS) + if priori_fps != fps: + cap.release() + raise ValueError("The frame rate of all input videos needs to be consistent.") + + prioris=[] + idx = 0 + while True: + ret, frame = cap.read() + if not ret: + break + if(idx >= n_total_frames): + break + img = Image.fromarray(frame[...,::-1]) + if img.size != img_size: + img = img.resize(img_size) + prioris.append(img) + idx += 1 + cap.release() + + os.remove(priori) # remove priori + + return prioris + +def read_video(validation_image, video_length, nframes, max_img_size): + vframes, aframes, info = torchvision.io.read_video(filename=validation_image, pts_unit='sec', end_pts=video_length) # RGB + fps = info['video_fps'] + n_total_frames = int(video_length * fps) + n_clip = int(np.ceil(n_total_frames/nframes)) + + frames = list(vframes.numpy())[:n_total_frames] + frames = [Image.fromarray(f) for f in frames] + max_size = max(frames[0].size) + if(max_size<256): + raise ValueError("The resolution of the uploaded video must be larger than 256x256.") + if(max_size>4096): + raise ValueError("The resolution of the uploaded video must be smaller than 4096x4096.") + if max_size>max_img_size: + ratio = max_size/max_img_size + ratio_size = (int(frames[0].size[0]/ratio),int(frames[0].size[1]/ratio)) + img_size = (ratio_size[0]-ratio_size[0]%8, ratio_size[1]-ratio_size[1]%8) + resize_flag=True + elif (frames[0].size[0]%8==0) and (frames[0].size[1]%8==0): + img_size = frames[0].size + resize_flag=False + else: + ratio_size = frames[0].size + img_size = (ratio_size[0]-ratio_size[0]%8, ratio_size[1]-ratio_size[1]%8) + resize_flag=True + if resize_flag: + frames = resize_frames(frames, img_size) + img_size = frames[0].size + + return frames, fps, img_size, n_clip, n_total_frames + + +class DiffuEraser: + def __init__( + self, device, base_model_path, vae_path, diffueraser_path, revision=None, + ckpt="Normal CFG 4-Step", mode="sd15", loaded=None): + self.device = device + + ## load model + self.vae = AutoencoderKL.from_pretrained(vae_path) + self.noise_scheduler = DDPMScheduler.from_pretrained(base_model_path, + subfolder="scheduler", + prediction_type="v_prediction", + timestep_spacing="trailing", + rescale_betas_zero_snr=True + ) + self.tokenizer = AutoTokenizer.from_pretrained( + base_model_path, + subfolder="tokenizer", + use_fast=False, + ) + text_encoder_cls = import_model_class_from_model_name_or_path(base_model_path,revision) + self.text_encoder = text_encoder_cls.from_pretrained( + base_model_path, subfolder="text_encoder" + ) + self.brushnet = BrushNetModel.from_pretrained(diffueraser_path, subfolder="brushnet") + self.unet_main = UNetMotionModel.from_pretrained( 
+ diffueraser_path, subfolder="unet_main", + ) + + ## set pipeline + self.pipeline = StableDiffusionDiffuEraserPipeline.from_pretrained( + base_model_path, + vae=self.vae, + text_encoder=self.text_encoder, + tokenizer=self.tokenizer, + unet=self.unet_main, + brushnet=self.brushnet + ).to(self.device, torch.float16) + self.pipeline.scheduler = UniPCMultistepScheduler.from_config(self.pipeline.scheduler.config) + self.pipeline.set_progress_bar_config(disable=True) + + self.noise_scheduler = UniPCMultistepScheduler.from_config(self.pipeline.scheduler.config) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + + ## use PCM + self.ckpt = ckpt + PCM_ckpts = checkpoints[ckpt][0].format(mode) + self.guidance_scale = checkpoints[ckpt][2] + if loaded != (ckpt + mode): + self.pipeline.load_lora_weights( + "weights/PCM_Weights", weight_name=PCM_ckpts, subfolder=mode + ) + loaded = ckpt + mode + + if ckpt == "LCM-Like LoRA": + self.pipeline.scheduler = LCMScheduler() + else: + self.pipeline.scheduler = TCDScheduler( + num_train_timesteps=1000, + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + timestep_spacing="trailing", + ) + self.num_inference_steps = checkpoints[ckpt][1] + self.guidance_scale = 0 + + def forward(self, validation_image, validation_mask, priori, output_path, + max_img_size = 1280, video_length=2, mask_dilation_iter=4, + nframes=22, seed=None, revision = None, guidance_scale=None, blended=True): + validation_prompt = "" # + guidance_scale_final = self.guidance_scale if guidance_scale==None else guidance_scale + + if (max_img_size<256 or max_img_size>1920): + raise ValueError("The max_img_size must be larger than 256, smaller than 1920.") + + ################ read input video ################ + frames, fps, img_size, n_clip, n_total_frames = read_video(validation_image, video_length, nframes, max_img_size) + video_len = len(frames) + + ################ read mask ################ + validation_masks_input, validation_images_input = read_mask(validation_mask, fps, video_len, img_size, mask_dilation_iter, frames) + + ################ read priori ################ + prioris = read_priori(priori, fps, n_total_frames, img_size) + + ## recheck + n_total_frames = min(min(len(frames), len(validation_masks_input)), len(prioris)) + if(n_total_frames<22): + raise ValueError("The effective video duration is too short. 
Please make sure that the number of frames of video, mask, and priori is at least greater than 22 frames.") + validation_masks_input = validation_masks_input[:n_total_frames] + validation_images_input = validation_images_input[:n_total_frames] + frames = frames[:n_total_frames] + prioris = prioris[:n_total_frames] + + prioris = resize_frames(prioris) + validation_masks_input = resize_frames(validation_masks_input) + validation_images_input = resize_frames(validation_images_input) + resized_frames = resize_frames(frames) + + ############################################## + # DiffuEraser inference + ############################################## + print("DiffuEraser inference...") + if seed is None: + generator = None + else: + generator = torch.Generator(device=self.device).manual_seed(seed) + + ## random noise + real_video_length = len(validation_images_input) + tar_width, tar_height = validation_images_input[0].size + shape = ( + nframes, + 4, + tar_height//8, + tar_width//8 + ) + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet_main is not None: + prompt_embeds_dtype = self.unet_main.dtype + else: + prompt_embeds_dtype = torch.float16 + noise_pre = randn_tensor(shape, device=torch.device(self.device), dtype=prompt_embeds_dtype, generator=generator) + noise = repeat(noise_pre, "t c h w->(repeat t) c h w", repeat=n_clip)[:real_video_length,...] + + ################ prepare priori ################ + images_preprocessed = [] + for image in prioris: + image = self.image_processor.preprocess(image, height=tar_height, width=tar_width).to(dtype=torch.float32) + image = image.to(device=torch.device(self.device), dtype=torch.float16) + images_preprocessed.append(image) + pixel_values = torch.cat(images_preprocessed) + + with torch.no_grad(): + pixel_values = pixel_values.to(dtype=torch.float16) + latents = [] + num=4 + for i in range(0, pixel_values.shape[0], num): + latents.append(self.vae.encode(pixel_values[i : i + num]).latent_dist.sample()) + latents = torch.cat(latents, dim=0) + latents = latents * self.vae.config.scaling_factor #[(b f), c1, h, w], c1=4 + torch.cuda.empty_cache() + timesteps = torch.tensor([0], device=self.device) + timesteps = timesteps.long() + + validation_masks_input_ori = copy.deepcopy(validation_masks_input) + resized_frames_ori = copy.deepcopy(resized_frames) + ################ Pre-inference ################ + if n_total_frames > nframes*2: ## do pre-inference only when number of input frames is larger than nframes*2 + ## sample + step = n_total_frames / nframes + sample_index = [int(i * step) for i in range(nframes)] + sample_index = sample_index[:22] + validation_masks_input_pre = [validation_masks_input[i] for i in sample_index] + validation_images_input_pre = [validation_images_input[i] for i in sample_index] + latents_pre = torch.stack([latents[i] for i in sample_index]) + + ## add proiri + noisy_latents_pre = self.noise_scheduler.add_noise(latents_pre, noise_pre, timesteps) + latents_pre = noisy_latents_pre + + with torch.no_grad(): + latents_pre_out = self.pipeline( + num_frames=nframes, + prompt=validation_prompt, + images=validation_images_input_pre, + masks=validation_masks_input_pre, + num_inference_steps=self.num_inference_steps, + generator=generator, + guidance_scale=guidance_scale_final, + latents=latents_pre, + ).latents + torch.cuda.empty_cache() + + def decode_latents(latents, weight_dtype): + latents = 1 / self.vae.config.scaling_factor * latents + video = [] + for t in range(latents.shape[0]): + 
video.append(self.vae.decode(latents[t:t+1, ...].to(weight_dtype)).sample) + video = torch.concat(video, dim=0) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + with torch.no_grad(): + video_tensor_temp = decode_latents(latents_pre_out, weight_dtype=torch.float16) + images_pre_out = self.image_processor.postprocess(video_tensor_temp, output_type="pil") + torch.cuda.empty_cache() + + ## replace input frames with updated frames + black_image = Image.new('L', validation_masks_input[0].size, color=0) + for i,index in enumerate(sample_index): + latents[index] = latents_pre_out[i] + validation_masks_input[index] = black_image + validation_images_input[index] = images_pre_out[i] + resized_frames[index] = images_pre_out[i] + else: + latents_pre_out=None + sample_index=None + gc.collect() + torch.cuda.empty_cache() + + ################ Frame-by-frame inference ################ + ## add priori + noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps) + latents = noisy_latents + with torch.no_grad(): + images = self.pipeline( + num_frames=nframes, + prompt=validation_prompt, + images=validation_images_input, + masks=validation_masks_input, + num_inference_steps=self.num_inference_steps, + generator=generator, + guidance_scale=guidance_scale_final, + latents=latents, + ).frames + images = images[:real_video_length] + + gc.collect() + torch.cuda.empty_cache() + + ################ Compose ################ + binary_masks = validation_masks_input_ori + mask_blurreds = [] + if blended: + # blur, you can adjust the parameters for better performance + for i in range(len(binary_masks)): + mask_blurred = cv2.GaussianBlur(np.array(binary_masks[i]), (21, 21), 0)/255. + binary_mask = 1-(1-np.array(binary_masks[i])/255.) * (1-mask_blurred) + mask_blurreds.append(Image.fromarray((binary_mask*255).astype(np.uint8))) + binary_masks = mask_blurreds + + comp_frames = [] + for i in range(len(images)): + mask = np.expand_dims(np.array(binary_masks[i]),2).repeat(3, axis=2).astype(np.float32)/255. 
+ img = (np.array(images[i]).astype(np.uint8) * mask \ + + np.array(resized_frames_ori[i]).astype(np.uint8) * (1 - mask)).astype(np.uint8) + comp_frames.append(Image.fromarray(img)) + + default_fps = fps + writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), + default_fps, comp_frames[0].size) + for f in range(real_video_length): + img = np.array(comp_frames[f]).astype(np.uint8) + writer.write(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) + writer.release() + ################################ + + return output_path + + + + diff --git a/diffueraser/pipeline_diffueraser.py b/diffueraser/pipeline_diffueraser.py new file mode 100644 index 0000000000000000000000000000000000000000..db4e230e8a7c627651cd8319a31c71823f20f390 --- /dev/null +++ b/diffueraser/pipeline_diffueraser.py @@ -0,0 +1,1349 @@ +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union +import numpy as np +import PIL.Image +from einops import rearrange, repeat +from dataclasses import dataclass +import copy +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection + +from diffusers.image_processor import PipelineImageInput, VaeImageProcessor +from diffusers.loaders import FromSingleFileMixin, IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, ImageProjection +from diffusers.models.lora import adjust_lora_scale_text_encoder +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + USE_PEFT_BACKEND, + deprecate, + logging, + replace_example_docstring, + scale_lora_layers, + unscale_lora_layers, + BaseOutput +) +from diffusers.utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin +from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker +from diffusers import ( + AutoencoderKL, + DDPMScheduler, + UniPCMultistepScheduler, +) + +from libs.unet_2d_condition import UNet2DConditionModel +from libs.brushnet_CA import BrushNetModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` + must be `None`. 
+ + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys()) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + +def get_frames_context_swap(total_frames=192, overlap=4, num_frames_per_clip=24): + if total_framesnum_frames_per_clip: + ## [0,num_frames_per_clip-1], [num_frames_per_clip, 2*num_frames_per_clip-1].... + for k in range(0,n-num_frames_per_clip,num_frames_per_clip-overlap): + context_list.append(sample_interval[k:k+num_frames_per_clip]) + if k+num_frames_per_clip < n and i==1: + context_list.append(sample_interval[n-num_frames_per_clip:n]) + context_list_swap.append(sample_interval[0:num_frames_per_clip]) + for k in range(num_frames_per_clip//2, n-num_frames_per_clip, num_frames_per_clip-overlap): + context_list_swap.append(sample_interval[k:k+num_frames_per_clip]) + if k+num_frames_per_clip < n and i==1: + context_list_swap.append(sample_interval[n-num_frames_per_clip:n]) + if n==num_frames_per_clip: + context_list.append(sample_interval[n-num_frames_per_clip:n]) + context_list_swap.append(sample_interval[n-num_frames_per_clip:n]) + return context_list, context_list_swap + +@dataclass +class DiffuEraserPipelineOutput(BaseOutput): + frames: Union[torch.Tensor, np.ndarray] + latents: Union[torch.Tensor, np.ndarray] + +class StableDiffusionDiffuEraserPipeline( + DiffusionPipeline, + StableDiffusionMixin, + TextualInversionLoaderMixin, + LoraLoaderMixin, + IPAdapterMixin, + FromSingleFileMixin, +): + r""" + Pipeline for video inpainting using Video Diffusion Model with BrushNet guidance. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods + implemented for all pipelines (downloading, saving, running on a particular device, etc.). + + The pipeline also inherits the following loading methods: + - [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings + - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights + - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights + - [`~loaders.FromSingleFileMixin.from_single_file`] for loading `.ckpt` files + - [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations. + text_encoder ([`~transformers.CLIPTextModel`]): + Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)). + tokenizer ([`~transformers.CLIPTokenizer`]): + A `CLIPTokenizer` to tokenize text. + unet ([`UNet2DConditionModel`]): + A `UNet2DConditionModel` to denoise the encoded image latents. + brushnet ([`BrushNetModel`]`): + Provides additional conditioning to the `unet` during the denoising process. 
+ scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for more details + about a model's potential harms. + feature_extractor ([`~transformers.CLIPImageProcessor`]): + A `CLIPImageProcessor` to extract features from generated images; used as inputs to the `safety_checker`. + """ + + model_cpu_offload_seq = "text_encoder->image_encoder->unet->vae" + _optional_components = ["safety_checker", "feature_extractor", "image_encoder"] + _exclude_from_cpu_offload = ["safety_checker"] + _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + brushnet: BrushNetModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPImageProcessor, + image_encoder: CLIPVisionModelWithProjection = None, + requires_safety_checker: bool = True, + ): + super().__init__() + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + brushnet=brushnet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + image_encoder=image_encoder, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + **kwargs, + ): + deprecation_message = "`_encode_prompt()` is deprecated and it will be removed in a future version. Use `encode_prompt()` instead. Also, be aware that the output format changed from a concatenated tensor to a tuple." 
+ deprecate("_encode_prompt()", "1.0.0", deprecation_message, standard_warn=False) + + prompt_embeds_tuple = self.encode_prompt( + prompt=prompt, + device=device, + num_images_per_prompt=num_images_per_prompt, + do_classifier_free_guidance=do_classifier_free_guidance, + negative_prompt=negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=lora_scale, + **kwargs, + ) + + # concatenate for backwards comp + prompt_embeds = torch.cat([prompt_embeds_tuple[1], prompt_embeds_tuple[0]]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_prompt + def encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + lora_scale (`float`, *optional*): + A LoRA scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. 
+ """ + # set lora scale so that monkey patched LoRA + # function of text encoder can correctly access it + if lora_scale is not None and isinstance(self, LoraLoaderMixin): + self._lora_scale = lora_scale + + # dynamically adjust the LoRA scale + if not USE_PEFT_BACKEND: + adjust_lora_scale_text_encoder(self.text_encoder, lora_scale) + else: + scale_lora_layers(self.text_encoder, lora_scale) + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + if clip_skip is None: + prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=attention_mask) + prompt_embeds = prompt_embeds[0] + else: + prompt_embeds = self.text_encoder( + text_input_ids.to(device), attention_mask=attention_mask, output_hidden_states=True + ) + # Access the `hidden_states` first, that contains a tuple of + # all the hidden states from the encoder layers. Then index into + # the tuple to access the hidden states from the desired layer. + prompt_embeds = prompt_embeds[-1][-(clip_skip + 1)] + # We also need to apply the final LayerNorm here to not mess with the + # representations. The `last_hidden_states` that we typically use for + # obtaining the final prompt representations passes through the LayerNorm + # layer. 
+ prompt_embeds = self.text_encoder.text_model.final_layer_norm(prompt_embeds) + + if self.text_encoder is not None: + prompt_embeds_dtype = self.text_encoder.dtype + elif self.unet is not None: + prompt_embeds_dtype = self.unet.dtype + else: + prompt_embeds_dtype = prompt_embeds.dtype + + prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif prompt is not None and type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: process multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + if isinstance(self, LoraLoaderMixin) and USE_PEFT_BACKEND: + # Retrieve the original scale by scaling back the LoRA layers + unscale_lora_layers(self.text_encoder, lora_scale) + + return prompt_embeds, negative_prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.encode_image + def encode_image(self, image, device, num_images_per_prompt, output_hidden_states=None): + dtype = next(self.image_encoder.parameters()).dtype + + if not isinstance(image, torch.Tensor): + image = self.feature_extractor(image, return_tensors="pt").pixel_values + + image = image.to(device=device, dtype=dtype) + if output_hidden_states: + image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2] + image_enc_hidden_states = 
image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_enc_hidden_states = self.image_encoder( + torch.zeros_like(image), output_hidden_states=True + ).hidden_states[-2] + uncond_image_enc_hidden_states = uncond_image_enc_hidden_states.repeat_interleave( + num_images_per_prompt, dim=0 + ) + return image_enc_hidden_states, uncond_image_enc_hidden_states + else: + image_embeds = self.image_encoder(image).image_embeds + image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0) + uncond_image_embeds = torch.zeros_like(image_embeds) + + return image_embeds, uncond_image_embeds + + def decode_latents(self, latents, weight_dtype): + latents = 1 / self.vae.config.scaling_factor * latents + video = [] + for t in range(latents.shape[0]): + video.append(self.vae.decode(latents[t:t+1, ...].to(weight_dtype)).sample) + video = torch.concat(video, dim=0) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance + ): + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." + ) + + image_embeds = [] + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) + single_negative_image_embeds = torch.stack( + [single_negative_image_embeds] * num_images_per_prompt, dim=0 + ) + + if do_classifier_free_guidance: + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) + single_image_embeds = single_image_embeds.to(device) + + image_embeds.append(single_image_embeds) + else: + repeat_dims = [1] + image_embeds = [] + for single_image_embeds in ip_adapter_image_embeds: + if do_classifier_free_guidance: + single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2) + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + single_negative_image_embeds = single_negative_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:])) + ) + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) + else: + single_image_embeds = single_image_embeds.repeat( + num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:])) + ) + image_embeds.append(single_image_embeds) + + return image_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if 
self.safety_checker is None: + has_nsfw_concept = None + else: + if torch.is_tensor(image): + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") + else: + feature_extractor_input = self.image_processor.numpy_to_pil(image) + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents + def decode_latents(self, latents, weight_dtype): + latents = 1 / self.vae.config.scaling_factor * latents + video = [] + for t in range(latents.shape[0]): + video.append(self.vae.decode(latents[t:t+1, ...].to(weight_dtype)).sample) + video = torch.concat(video, dim=0) + + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + video = video.float() + return video + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + images, + masks, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ip_adapter_image=None, + ip_adapter_image_embeds=None, + brushnet_conditioning_scale=1.0, + control_guidance_start=0.0, + control_guidance_end=1.0, + callback_on_step_end_tensor_inputs=None, + ): + if callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if callback_on_step_end_tensor_inputs is not None and not all( + k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs + ): + raise ValueError( + f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}" + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." 
+ ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + # Check `image` + is_compiled = hasattr(F, "scaled_dot_product_attention") and isinstance( + self.brushnet, torch._dynamo.eval_frame.OptimizedModule + ) + if ( + isinstance(self.brushnet, BrushNetModel) + or is_compiled + and isinstance(self.brushnet._orig_mod, BrushNetModel) + ): + self.check_image(images, masks, prompt, prompt_embeds) + else: + assert False + + # Check `brushnet_conditioning_scale` + if ( + isinstance(self.brushnet, BrushNetModel) + or is_compiled + and isinstance(self.brushnet._orig_mod, BrushNetModel) + ): + if not isinstance(brushnet_conditioning_scale, float): + raise TypeError("For single brushnet: `brushnet_conditioning_scale` must be type `float`.") + else: + assert False + + if not isinstance(control_guidance_start, (tuple, list)): + control_guidance_start = [control_guidance_start] + + if not isinstance(control_guidance_end, (tuple, list)): + control_guidance_end = [control_guidance_end] + + if len(control_guidance_start) != len(control_guidance_end): + raise ValueError( + f"`control_guidance_start` has {len(control_guidance_start)} elements, but `control_guidance_end` has {len(control_guidance_end)} elements. Make sure to provide the same number of elements to each list." + ) + + for start, end in zip(control_guidance_start, control_guidance_end): + if start >= end: + raise ValueError( + f"control guidance start: {start} cannot be larger or equal to control guidance end: {end}." + ) + if start < 0.0: + raise ValueError(f"control guidance start: {start} can't be smaller than 0.") + if end > 1.0: + raise ValueError(f"control guidance end: {end} can't be larger than 1.0.") + + if ip_adapter_image is not None and ip_adapter_image_embeds is not None: + raise ValueError( + "Provide either `ip_adapter_image` or `ip_adapter_image_embeds`. Cannot leave both `ip_adapter_image` and `ip_adapter_image_embeds` defined." 
+ ) + + if ip_adapter_image_embeds is not None: + if not isinstance(ip_adapter_image_embeds, list): + raise ValueError( + f"`ip_adapter_image_embeds` has to be of type `list` but is {type(ip_adapter_image_embeds)}" + ) + elif ip_adapter_image_embeds[0].ndim not in [3, 4]: + raise ValueError( + f"`ip_adapter_image_embeds` has to be a list of 3D or 4D tensors but is {ip_adapter_image_embeds[0].ndim}D" + ) + + def check_image(self, images, masks, prompt, prompt_embeds): + for image in images: + image_is_pil = isinstance(image, PIL.Image.Image) + image_is_tensor = isinstance(image, torch.Tensor) + image_is_np = isinstance(image, np.ndarray) + image_is_pil_list = isinstance(image, list) and isinstance(image[0], PIL.Image.Image) + image_is_tensor_list = isinstance(image, list) and isinstance(image[0], torch.Tensor) + image_is_np_list = isinstance(image, list) and isinstance(image[0], np.ndarray) + + if ( + not image_is_pil + and not image_is_tensor + and not image_is_np + and not image_is_pil_list + and not image_is_tensor_list + and not image_is_np_list + ): + raise TypeError( + f"image must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(image)}" + ) + for mask in masks: + mask_is_pil = isinstance(mask, PIL.Image.Image) + mask_is_tensor = isinstance(mask, torch.Tensor) + mask_is_np = isinstance(mask, np.ndarray) + mask_is_pil_list = isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image) + mask_is_tensor_list = isinstance(mask, list) and isinstance(mask[0], torch.Tensor) + mask_is_np_list = isinstance(mask, list) and isinstance(mask[0], np.ndarray) + + if ( + not mask_is_pil + and not mask_is_tensor + and not mask_is_np + and not mask_is_pil_list + and not mask_is_tensor_list + and not mask_is_np_list + ): + raise TypeError( + f"mask must be passed and be one of PIL image, numpy array, torch tensor, list of PIL images, list of numpy arrays or list of torch tensors, but is {type(mask)}" + ) + + if image_is_pil: + image_batch_size = 1 + else: + image_batch_size = len(image) + + if prompt is not None and isinstance(prompt, str): + prompt_batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + prompt_batch_size = len(prompt) + elif prompt_embeds is not None: + prompt_batch_size = prompt_embeds.shape[0] + + if image_batch_size != 1 and image_batch_size != prompt_batch_size: + raise ValueError( + f"If image batch size is not 1, image batch size must be same as prompt batch size. 
image batch size: {image_batch_size}, prompt batch size: {prompt_batch_size}" + ) + + def prepare_image( + self, + images, + width, + height, + batch_size, + num_images_per_prompt, + device, + dtype, + do_classifier_free_guidance=False, + guess_mode=False, + ): + images_new = [] + for image in images: + image = self.image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32) + image_batch_size = image.shape[0] + + if image_batch_size == 1: + repeat_by = batch_size + else: + # image batch size is the same as prompt batch size + repeat_by = num_images_per_prompt + + image = image.repeat_interleave(repeat_by, dim=0) + + image = image.to(device=device, dtype=dtype) + + # if do_classifier_free_guidance and not guess_mode: + # image = torch.cat([image] * 2) + images_new.append(image) + + return images_new + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents(self, batch_size, num_channels_latents, num_frames, height, width, dtype, device, generator, latents=None): + # shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + #b,c,n,h,w + shape = ( + batch_size, + num_channels_latents, + num_frames, + height // self.vae_scale_factor, + width // self.vae_scale_factor + ) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + # noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + noise = rearrange(randn_tensor(shape, generator=generator, device=device, dtype=dtype), "b c t h w -> (b t) c h w") + else: + noise = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = noise * self.scheduler.init_noise_sigma + return latents, noise + + @staticmethod + def temp_blend(a, b, overlap): + factor = torch.arange(overlap).to(b.device).view(overlap, 1, 1, 1) / (overlap - 1) + a[:overlap, ...] = (1 - factor) * a[:overlap, ...] + factor * b[:overlap, ...] + a[overlap:, ...] = b[overlap:, ...] 
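# This static helper cross-fades the first `overlap` frames of `a` into `b`
# with a linear ramp and copies the remaining frames of `b` verbatim. A small
# worked example (values assumed purely for illustration): with overlap=4 the
# blend weights are factor = [0, 1/3, 2/3, 1], so frame 0 keeps `a` untouched
# and frame 3 already comes entirely from `b`:
#   a, b = torch.zeros(6, 1, 1, 1), torch.ones(6, 1, 1, 1)
#   temp_blend(a.clone(), b, overlap=4)[:, 0, 0, 0]
#   # -> tensor([0.0000, 0.3333, 0.6667, 1.0000, 1.0000, 1.0000])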
+ return a + + # Copied from diffusers.pipelines.latent_consistency_models.pipeline_latent_consistency_text2img.LatentConsistencyModelPipeline.get_guidance_scale_embedding + def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32): + """ + See https://github.com/google-research/vdm/blob/dc27b98a554f65cdc654b800da5aa1846545d41b/model_vdm.py#L298 + + Args: + timesteps (`torch.Tensor`): + generate embedding vectors at these timesteps + embedding_dim (`int`, *optional*, defaults to 512): + dimension of the embeddings to generate + dtype: + data type of the generated embeddings + + Returns: + `torch.FloatTensor`: Embedding vectors with shape `(len(timesteps), embedding_dim)` + """ + assert len(w.shape) == 1 + w = w * 1000.0 + + half_dim = embedding_dim // 2 + emb = torch.log(torch.tensor(10000.0)) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=dtype) * -emb) + emb = w.to(dtype)[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1)) + assert emb.shape == (w.shape[0], embedding_dim) + return emb + + @property + def guidance_scale(self): + return self._guidance_scale + + @property + def clip_skip(self): + return self._clip_skip + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + @property + def do_classifier_free_guidance(self): + return self._guidance_scale > 1 and self.unet.config.time_cond_proj_dim is None + + @property + def cross_attention_kwargs(self): + return self._cross_attention_kwargs + + @property + def num_timesteps(self): + return self._num_timesteps + + # based on BrushNet: https://github.com/TencentARC/BrushNet/blob/main/src/diffusers/pipelines/brushnet/pipeline_brushnet.py + @torch.no_grad() + def __call__( + self, + num_frames: Optional[int] = 24, + prompt: Union[str, List[str]] = None, + images: PipelineImageInput = None, ##masked images + masks: PipelineImageInput = None, + height: Optional[int] = None, + width: Optional[int] = None, + num_inference_steps: int = 50, + timesteps: List[int] = None, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + brushnet_conditioning_scale: Union[float, List[float]] = 1.0, + guess_mode: bool = False, + control_guidance_start: Union[float, List[float]] = 0.0, + control_guidance_end: Union[float, List[float]] = 1.0, + clip_skip: Optional[int] = None, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + callback_on_step_end_tensor_inputs: List[str] = ["latents"], + **kwargs, + ): + r""" + The call function to the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`. 
+ image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The BrushNet branch input condition to provide guidance to the `unet` for generation. + mask (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`, `List[np.ndarray]`,: + `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`): + The BrushNet branch input condition to provide guidance to the `unet` for generation. + height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument + in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is + passed will be used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 7.5): + A higher guidance scale value encourages the model to generate images closely linked to the text + `prompt` at the expense of lower image quality. Guidance scale is enabled when `guidance_scale > 1`. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide what to not include in image generation. If not defined, you need to + pass `negative_prompt_embeds` instead. Ignored when not using guidance (`guidance_scale < 1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) from the [DDIM](https://arxiv.org/abs/2010.02502) paper. Only applies + to the [`~schedulers.DDIMScheduler`], and is ignored in other schedulers. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make + generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor is generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs (prompt weighting). If not + provided, text embeddings are generated from the `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs (prompt weighting). If + not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. + ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of IP-adapters. 
+ Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. It should contain the negative image embedding + if `do_classifier_free_guidance` is set to `True`. + If not provided, embeddings are computed from the `ip_adapter_image` input argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generated image. Choose between `PIL.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that calls every `callback_steps` steps during inference. The function is called with the + following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function is called. If not specified, the callback is called at + every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in + [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + brushnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0): + The outputs of the BrushNet are multiplied by `brushnet_conditioning_scale` before they are added + to the residual in the original `unet`. If multiple BrushNets are specified in `init`, you can set + the corresponding scale as a list. + guess_mode (`bool`, *optional*, defaults to `False`): + The BrushNet encoder tries to recognize the content of the input image even if you remove all + prompts. A `guidance_scale` value between 3.0 and 5.0 is recommended. + control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0): + The percentage of total steps at which the BrushNet starts applying. + control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0): + The percentage of total steps at which the BrushNet stops applying. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + callback_on_step_end_tensor_inputs (`List`, *optional*): + The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list + will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the + `._callback_tensor_inputs` attribute of your pipeine class. + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] is returned, + otherwise a `tuple` is returned where the first element is a list with the generated images and the + second element is a list of `bool`s indicating whether the corresponding generated image contains + "not-safe-for-work" (nsfw) content. 
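        A minimal usage sketch (the pipeline class name, weight paths and frame/mask
        file paths below are placeholders, since they are defined elsewhere in this
        repository; `images` are the masked input frames and `masks` the per-frame
        inpainting masks, as documented above):

            import torch
            from PIL import Image

            # placeholder: substitute the actual pipeline class and weights
            # pipe = DiffuEraserPipeline.from_pretrained(...).to("cuda")

            frames = [Image.open(f"frames/{i:05d}.png").convert("RGB") for i in range(24)]
            masks = [Image.open(f"masks/{i:05d}.png").convert("L") for i in range(24)]

            # result = pipe(
            #     num_frames=24,
            #     prompt="clean background, high quality",
            #     images=frames,      # masked frames
            #     masks=masks,        # corresponding masks
            #     num_inference_steps=50,
            #     guidance_scale=7.5,
            #     generator=torch.Generator("cuda").manual_seed(0),
            # )
            # video = result.frames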
+ """ + + callback = kwargs.pop("callback", None) + callback_steps = kwargs.pop("callback_steps", None) + + if callback is not None: + deprecate( + "callback", + "1.0.0", + "Passing `callback` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + if callback_steps is not None: + deprecate( + "callback_steps", + "1.0.0", + "Passing `callback_steps` as an input argument to `__call__` is deprecated, consider using `callback_on_step_end`", + ) + + brushnet = self.brushnet._orig_mod if is_compiled_module(self.brushnet) else self.brushnet + + # align format for control guidance + if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list): + control_guidance_start = len(control_guidance_end) * [control_guidance_start] + elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list): + control_guidance_end = len(control_guidance_start) * [control_guidance_end] + elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list): + control_guidance_start, control_guidance_end = ( + [control_guidance_start], + [control_guidance_end], + ) + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, + images, + masks, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ip_adapter_image, + ip_adapter_image_embeds, + brushnet_conditioning_scale, + control_guidance_start, + control_guidance_end, + callback_on_step_end_tensor_inputs, + ) + + self._guidance_scale = guidance_scale + self._clip_skip = clip_skip + self._cross_attention_kwargs = cross_attention_kwargs + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + global_pool_conditions = ( + brushnet.config.global_pool_conditions + if isinstance(brushnet, BrushNetModel) + else brushnet.nets[0].config.global_pool_conditions + ) + guess_mode = guess_mode or global_pool_conditions + video_length = len(images) + + # 3. Encode input prompt + text_encoder_lora_scale = ( + self.cross_attention_kwargs.get("scale", None) if self.cross_attention_kwargs is not None else None + ) + prompt_embeds, negative_prompt_embeds = self.encode_prompt( + prompt, + device, + num_images_per_prompt, + self.do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + lora_scale=text_encoder_lora_scale, + clip_skip=self.clip_skip, + ) + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + if self.do_classifier_free_guidance: + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, + ip_adapter_image_embeds, + device, + batch_size * num_images_per_prompt, + self.do_classifier_free_guidance, + ) + + # 4. 
Prepare image + if isinstance(brushnet, BrushNetModel): + images = self.prepare_image( + images=images, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=brushnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + original_masks = self.prepare_image( + images=masks, + width=width, + height=height, + batch_size=batch_size * num_images_per_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + dtype=brushnet.dtype, + do_classifier_free_guidance=self.do_classifier_free_guidance, + guess_mode=guess_mode, + ) + original_masks_new = [] + for original_mask in original_masks: + original_mask=(original_mask.sum(1)[:,None,:,:] < 0).to(images[0].dtype) + original_masks_new.append(original_mask) + original_masks = original_masks_new + + height, width = images[0].shape[-2:] + else: + assert False + + # 5. Prepare timesteps + timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) + self._num_timesteps = len(timesteps) + + # 6. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents, noise = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + num_frames, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6.1 prepare condition latents + images = torch.cat(images) + images = images.to(dtype=images[0].dtype) + conditioning_latents = [] + num=4 + for i in range(0, images.shape[0], num): + conditioning_latents.append(self.vae.encode(images[i : i + num]).latent_dist.sample()) + conditioning_latents = torch.cat(conditioning_latents, dim=0) + + conditioning_latents = conditioning_latents * self.vae.config.scaling_factor #[(f c h w],c2=4 + + original_masks = torch.cat(original_masks) + masks = torch.nn.functional.interpolate( + original_masks, + size=( + latents.shape[-2], + latents.shape[-1] + ) + ) ##[ f c h w],c=1 + + conditioning_latents=torch.concat([conditioning_latents,masks],1) + + # 6.5 Optionally get Guidance Scale Embedding + timestep_cond = None + if self.unet.config.time_cond_proj_dim is not None: + guidance_scale_tensor = torch.tensor(self.guidance_scale - 1).repeat(batch_size * num_images_per_prompt) + timestep_cond = self.get_guidance_scale_embedding( + guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim + ).to(device=device, dtype=latents.dtype) + + # 7. Prepare extra step kwargs. 
TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7.1 Add image embeds for IP-Adapter + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) + + # 7.2 Create tensor stating which brushnets to keep + brushnet_keep = [] + for i in range(len(timesteps)): + keeps = [ + 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e) + for s, e in zip(control_guidance_start, control_guidance_end) + ] + brushnet_keep.append(keeps[0] if isinstance(brushnet, BrushNetModel) else keeps) + + + overlap = num_frames//4 + context_list, context_list_swap = get_frames_context_swap(video_length, overlap=overlap, num_frames_per_clip=num_frames) + scheduler_status = [copy.deepcopy(self.scheduler.__dict__)] * len(context_list) + scheduler_status_swap = [copy.deepcopy(self.scheduler.__dict__)] * len(context_list_swap) + count = torch.zeros_like(latents) + value = torch.zeros_like(latents) + + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + is_unet_compiled = is_compiled_module(self.unet) + is_brushnet_compiled = is_compiled_module(self.brushnet) + is_torch_higher_equal_2_1 = is_torch_version(">=", "2.1") + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + + count.zero_() + value.zero_() + ## swap + if (i%2==1): + context_list_choose = context_list_swap + scheduler_status_choose = scheduler_status_swap + else: + context_list_choose = context_list + scheduler_status_choose = scheduler_status + + + for j, context in enumerate(context_list_choose): + self.scheduler.__dict__.update(scheduler_status_choose[j]) + + latents_j = latents[context, :, :, :] + + # Relevant thread: + # https://dev-discuss.pytorch.org/t/cudagraphs-in-pytorch-2-0/1428 + if (is_unet_compiled and is_brushnet_compiled) and is_torch_higher_equal_2_1: + torch._inductor.cudagraph_mark_step_begin() + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents_j] * 2) if self.do_classifier_free_guidance else latents_j + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # brushnet(s) inference + if guess_mode and self.do_classifier_free_guidance: + # Infer BrushNet only for the conditional batch. 
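# With classifier-free guidance the prompt embeddings are stacked as
# [uncond, cond] along dim 0, so in guess mode only the conditional half is
# kept here and then broadcast to every frame (shapes illustrative):
#   prompt_embeds                  # (2*B, 77, C)  -> [uncond, cond]
#   prompt_embeds.chunk(2)[1]      # (B, 77, C)    -> conditional half only
#   repeat(..., "b c d -> b t c d", t=num_frames)  # one copy per frame
# After the BrushNet call below, its residuals are re-expanded with zeros for
# the unconditional half (torch.cat([torch.zeros_like(d), d])) so the UNet can
# still apply guidance to the full batch.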
+ control_model_input = latents_j + control_model_input = self.scheduler.scale_model_input(control_model_input, t) + brushnet_prompt_embeds = prompt_embeds.chunk(2)[1] + brushnet_prompt_embeds = rearrange(repeat(brushnet_prompt_embeds, "b c d -> b t c d", t=num_frames), 'b t c d -> (b t) c d') + else: + control_model_input = latent_model_input + brushnet_prompt_embeds = prompt_embeds + if self.do_classifier_free_guidance: + neg_brushnet_prompt_embeds, brushnet_prompt_embeds = brushnet_prompt_embeds.chunk(2) + brushnet_prompt_embeds = rearrange(repeat(brushnet_prompt_embeds, "b c d -> b t c d", t=num_frames), 'b t c d -> (b t) c d') + neg_brushnet_prompt_embeds = rearrange(repeat(neg_brushnet_prompt_embeds, "b c d -> b t c d", t=num_frames), 'b t c d -> (b t) c d') + brushnet_prompt_embeds = torch.cat([neg_brushnet_prompt_embeds, brushnet_prompt_embeds]) + else: + brushnet_prompt_embeds = rearrange(repeat(brushnet_prompt_embeds, "b c d -> b t c d", t=num_frames), 'b t c d -> (b t) c d') + + if isinstance(brushnet_keep[i], list): + cond_scale = [c * s for c, s in zip(brushnet_conditioning_scale, brushnet_keep[i])] + else: + brushnet_cond_scale = brushnet_conditioning_scale + if isinstance(brushnet_cond_scale, list): + brushnet_cond_scale = brushnet_cond_scale[0] + cond_scale = brushnet_cond_scale * brushnet_keep[i] + + + down_block_res_samples, mid_block_res_sample, up_block_res_samples = self.brushnet( + control_model_input, + t, + encoder_hidden_states=brushnet_prompt_embeds, + brushnet_cond=torch.cat([conditioning_latents[context, :, :, :]]*2) if self.do_classifier_free_guidance else conditioning_latents[context, :, :, :], + conditioning_scale=cond_scale, + guess_mode=guess_mode, + return_dict=False, + ) + + if guess_mode and self.do_classifier_free_guidance: + # Infered BrushNet only for the conditional batch. + # To apply the output of BrushNet to both the unconditional and conditional batches, + # add 0 to the unconditional batch to keep it unchanged. + down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples] + mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample]) + up_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in up_block_res_samples] + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + timestep_cond=timestep_cond, + cross_attention_kwargs=self.cross_attention_kwargs, + down_block_add_samples=down_block_res_samples, + mid_block_add_sample=mid_block_res_sample, + up_block_add_samples=up_block_res_samples, + added_cond_kwargs=added_cond_kwargs, + return_dict=False, + num_frames=num_frames, + )[0] + + # perform guidance + if self.do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents_j = self.scheduler.step(noise_pred, t, latents_j, **extra_step_kwargs, return_dict=False)[0] + + count[context, ...] += 1 + + if j==0: + value[context, ...] += latents_j + else: + overlap_index_list = [index for index, value in enumerate(count[context, 0, 0, 0]) if value > 1] + overlap_cur = len(overlap_index_list) + ratio_next = torch.linspace(0, 1, overlap_cur+2)[1:-1] + ratio_pre = 1-ratio_next + for i_overlap in overlap_index_list: + value[context[i_overlap], ...] 
= value[context[i_overlap], ...]*ratio_pre[i_overlap] + latents_j[i_overlap, ...]*ratio_next[i_overlap] + value[context[i_overlap:num_frames], ...] = latents_j[i_overlap:num_frames, ...] + + latents = value.clone() + + if callback_on_step_end is not None: + callback_kwargs = {} + for k in callback_on_step_end_tensor_inputs: + callback_kwargs[k] = locals()[k] + callback_outputs = callback_on_step_end(self, i, t, callback_kwargs) + + latents = callback_outputs.pop("latents", latents) + prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds) + negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds) + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + step_idx = i // getattr(self.scheduler, "order", 1) + callback(step_idx, t, latents) + + + # If we do sequential model offloading, let's offload unet and brushnet + # manually for max memory savings + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.unet.to("cpu") + self.brushnet.to("cpu") + torch.cuda.empty_cache() + + if output_type == "latent": + image = latents + has_nsfw_concept = None + return DiffuEraserPipelineOutput(frames=image, nsfw_content_detected=has_nsfw_concept) + + video_tensor = self.decode_latents(latents, weight_dtype=prompt_embeds.dtype) + + if output_type == "pt": + video = video_tensor + else: + video = self.image_processor.postprocess(video_tensor, output_type=output_type) + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (video, has_nsfw_concept) + + return DiffuEraserPipelineOutput(frames=video, latents=latents) diff --git a/examples/example1/mask.mp4 b/examples/example1/mask.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..065f632cfe456ba808fa1c63fb7285fb52aef8b5 Binary files /dev/null and b/examples/example1/mask.mp4 differ diff --git a/examples/example1/video.mp4 b/examples/example1/video.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..c9c3acec1b37146bb9379de7df714fe013f4806b Binary files /dev/null and b/examples/example1/video.mp4 differ diff --git a/examples/example2/mask.mp4 b/examples/example2/mask.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..463ccf51075591edd7e2d9f8b528c29d786b63d1 --- /dev/null +++ b/examples/example2/mask.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:39849531b31960ee023cd33caf402afd4a4c1402276ba8afa04b7888feb52c3f +size 1249680 diff --git a/examples/example2/video.mp4 b/examples/example2/video.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..992b5fc234a255fc104f58f7bded5455b422f171 Binary files /dev/null and b/examples/example2/video.mp4 differ diff --git a/examples/example3/mask.mp4 b/examples/example3/mask.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..fe2ee61cb7930d81bc10baa3bb6d5983ae8940e9 Binary files /dev/null and b/examples/example3/mask.mp4 differ diff --git a/examples/example3/video.mp4 b/examples/example3/video.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..4671b8fca23896f621d3f0426e18073392254e35 --- /dev/null +++ b/examples/example3/video.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b21c936a305f80ed6707bad621712b24bd1e7a69f82ec7cdd949b18fd1a7fd56 +size 5657081 diff --git 
a/libs/brushnet_CA.py b/libs/brushnet_CA.py new file mode 100644 index 0000000000000000000000000000000000000000..cab2b6bca90530697607109d26491378862cde32 --- /dev/null +++ b/libs/brushnet_CA.py @@ -0,0 +1,939 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +from torch import nn +from torch.nn import functional as F + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput, logging +from diffusers.models.attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from diffusers.models.embeddings import TextImageProjection, TextImageTimeEmbedding, TextTimeEmbedding, TimestepEmbedding, Timesteps +from diffusers.models.modeling_utils import ModelMixin +from .unet_2d_blocks import ( + CrossAttnDownBlock2D, + DownBlock2D, + UNetMidBlock2D, + UNetMidBlock2DCrossAttn, + get_down_block, + get_mid_block, + get_up_block, + MidBlock2D +) + +# from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel +from libs.unet_2d_condition import UNet2DConditionModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class BrushNetOutput(BaseOutput): + """ + The output of [`BrushNetModel`]. + + Args: + up_block_res_samples (`tuple[torch.Tensor]`): + A tuple of upsample activations at different resolutions for each upsampling block. Each tensor should + be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be + used to condition the original UNet's upsampling activations. + down_block_res_samples (`tuple[torch.Tensor]`): + A tuple of downsample activations at different resolutions for each downsampling block. Each tensor should + be of shape `(batch_size, channel * resolution, height //resolution, width // resolution)`. Output can be + used to condition the original UNet's downsampling activations. + mid_down_block_re_sample (`torch.Tensor`): + The activation of the midde block (the lowest sample resolution). Each tensor should be of shape + `(batch_size, channel * lowest_resolution, height // lowest_resolution, width // lowest_resolution)`. + Output can be used to condition the original UNet's middle block activation. + """ + + up_block_res_samples: Tuple[torch.Tensor] + down_block_res_samples: Tuple[torch.Tensor] + mid_block_res_sample: torch.Tensor + + +class BrushNetModel(ModelMixin, ConfigMixin): + """ + A BrushNet model. + + Args: + in_channels (`int`, defaults to 4): + The number of channels in the input sample. + flip_sin_to_cos (`bool`, defaults to `True`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, defaults to 0): + The frequency shift to apply to the time embedding. + down_block_types (`tuple[str]`, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or + `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. 
+ only_cross_attention (`Union[bool, Tuple[bool]]`, defaults to `False`): + block_out_channels (`tuple[int]`, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, defaults to 2): + The number of layers per block. + downsample_padding (`int`, defaults to 1): + The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, defaults to 1): + The scale factor to use for the mid block. + act_fn (`str`, defaults to "silu"): + The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the normalization. If None, normalization and activation layers is skipped + in post-processing. + norm_eps (`float`, defaults to 1e-5): + The epsilon to use for the normalization. + cross_attention_dim (`int`, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int` or `Tuple[int]`, *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`Union[int, Tuple[int]]`, defaults to 8): + The dimension of the attention heads. + use_linear_projection (`bool`, defaults to `False`): + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from None, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + num_class_embeds (`int`, *optional*, defaults to 0): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + upcast_attention (`bool`, defaults to `False`): + resnet_time_scale_shift (`str`, defaults to `"default"`): + Time scale shift config for ResNet blocks (see `ResnetBlock2D`). Choose from `default` or `scale_shift`. + projection_class_embeddings_input_dim (`int`, *optional*, defaults to `None`): + The dimension of the `class_labels` input when `class_embed_type="projection"`. Required when + `class_embed_type="projection"`. + brushnet_conditioning_channel_order (`str`, defaults to `"rgb"`): + The channel order of conditional image. Will convert to `rgb` if it's `bgr`. + conditioning_embedding_out_channels (`tuple[int]`, *optional*, defaults to `(16, 32, 96, 256)`): + The tuple of output channel for each block in the `conditioning_embedding` layer. + global_pool_conditions (`bool`, defaults to `False`): + TODO(Patrick) - unused parameter. + addition_embed_type_num_heads (`int`, defaults to 64): + The number of heads to use for the `TextTimeEmbedding` layer. 
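    A minimal shape sketch of how the pipeline above drives this model (sizes are
    illustrative, and the output unpacking follows that pipeline's usage;
    `brushnet_cond` carries the 4-channel masked-image latents concatenated with
    the 1-channel resized mask, matching the default `conditioning_channels=5`):

        import torch

        # brushnet = BrushNetModel(cross_attention_dim=768)  # or BrushNetModel.from_unet(unet)
        sample = torch.randn(2, 4, 64, 64)               # noisy latents, in_channels=4
        brushnet_cond = torch.randn(2, 5, 64, 64)        # masked-image latents (4) + mask (1)
        encoder_hidden_states = torch.randn(2, 77, 768)  # text embeddings (dim is model dependent)
        # down_samples, mid_sample, up_samples = brushnet(
        #     sample,
        #     timestep=10,
        #     encoder_hidden_states=encoder_hidden_states,
        #     brushnet_cond=brushnet_cond,
        #     conditioning_scale=1.0,
        #     return_dict=False,
        # )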
+ """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 4, + conditioning_channels: int = 5, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str, ...] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", + up_block_types: Tuple[str, ...] = ( + "UpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int, ...]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + projection_class_embeddings_input_dim: Optional[int] = None, + brushnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), + global_pool_conditions: bool = False, + addition_embed_type_num_heads: int = 64, + ): + super().__init__() + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 
+ ) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + # input + conv_in_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in_condition = nn.Conv2d( + in_channels+conditioning_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 
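# Shape-wise (illustrative): with class_embed_type="projection", `class_labels`
# of shape (batch, projection_class_embeddings_input_dim) are fed to this
# TimestepEmbedding directly (no sinusoidal encoding first) and come out as
# (batch, time_embed_dim), where time_embed_dim = block_out_channels[0] * 4.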
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + self.down_blocks = nn.ModuleList([]) + self.brushnet_down_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + + brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + brushnet_block = zero_module(brushnet_block) + self.brushnet_down_blocks.append(brushnet_block) #零卷积 + + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[i], + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + downsample_padding=downsample_padding, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + self.down_blocks.append(down_block) + + for _ in range(layers_per_block): + brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + brushnet_block = zero_module(brushnet_block) + self.brushnet_down_blocks.append(brushnet_block) #零卷积 + + if not is_final_block: + brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + brushnet_block = zero_module(brushnet_block) + self.brushnet_down_blocks.append(brushnet_block) + + # mid + mid_block_channel = block_out_channels[-1] + + brushnet_block = nn.Conv2d(mid_block_channel, mid_block_channel, kernel_size=1) + 
brushnet_block = zero_module(brushnet_block) + self.brushnet_mid_block = brushnet_block + + self.mid_block = get_mid_block( + mid_block_type, + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=mid_block_channel, + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_transformer_layers_per_block = (list(reversed(transformer_layers_per_block))) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + + self.up_blocks = nn.ModuleList([]) + self.brushnet_up_blocks = nn.ModuleList([]) + + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block+1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resolution_idx=i, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=reversed_num_attention_heads[i], + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + for _ in range(layers_per_block+1): + brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + brushnet_block = zero_module(brushnet_block) + self.brushnet_up_blocks.append(brushnet_block) + + if not is_final_block: + brushnet_block = nn.Conv2d(output_channel, output_channel, kernel_size=1) + brushnet_block = zero_module(brushnet_block) + self.brushnet_up_blocks.append(brushnet_block) + + + @classmethod + def from_unet( + cls, + unet: UNet2DConditionModel, + brushnet_conditioning_channel_order: str = "rgb", + conditioning_embedding_out_channels: Optional[Tuple[int, ...]] = (16, 32, 96, 256), + load_weights_from_unet: bool = True, + conditioning_channels: int = 5, + ): + r""" + Instantiate a [`BrushNetModel`] from [`UNet2DConditionModel`]. + + Parameters: + unet (`UNet2DConditionModel`): + The UNet model weights to copy to the [`BrushNetModel`]. All configuration options are also copied + where applicable. 
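        A brief usage sketch (the checkpoint id is a placeholder, and the custom
        `libs.unet_2d_condition.UNet2DConditionModel` imported at the top of this
        file is assumed to expose the usual `from_pretrained` interface):

            from libs.unet_2d_condition import UNet2DConditionModel
            from libs.brushnet_CA import BrushNetModel

            unet = UNet2DConditionModel.from_pretrained(
                "runwayml/stable-diffusion-v1-5", subfolder="unet"  # placeholder checkpoint
            )
            brushnet = BrushNetModel.from_unet(unet, load_weights_from_unet=True)
            # conv_in_condition now takes 4 (latent) + 5 (condition) = 9 input channels,
            # with the UNet's conv_in weights copied into the first two 4-channel groups
            # and the extra mask channel left zero-initialized.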
+ """ + transformer_layers_per_block = ( + unet.config.transformer_layers_per_block if "transformer_layers_per_block" in unet.config else 1 + ) + encoder_hid_dim = unet.config.encoder_hid_dim if "encoder_hid_dim" in unet.config else None + encoder_hid_dim_type = unet.config.encoder_hid_dim_type if "encoder_hid_dim_type" in unet.config else None + addition_embed_type = unet.config.addition_embed_type if "addition_embed_type" in unet.config else None + addition_time_embed_dim = ( + unet.config.addition_time_embed_dim if "addition_time_embed_dim" in unet.config else None + ) + + brushnet = cls( + in_channels=unet.config.in_channels, + conditioning_channels=conditioning_channels, + flip_sin_to_cos=unet.config.flip_sin_to_cos, + freq_shift=unet.config.freq_shift, + # down_block_types=['DownBlock2D','DownBlock2D','DownBlock2D','DownBlock2D'], + down_block_types=[ + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ], + # mid_block_type='MidBlock2D', + mid_block_type="UNetMidBlock2DCrossAttn", + # up_block_types=['UpBlock2D','UpBlock2D','UpBlock2D','UpBlock2D'], + up_block_types=["UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"], + only_cross_attention=unet.config.only_cross_attention, + block_out_channels=unet.config.block_out_channels, + layers_per_block=unet.config.layers_per_block, + downsample_padding=unet.config.downsample_padding, + mid_block_scale_factor=unet.config.mid_block_scale_factor, + act_fn=unet.config.act_fn, + norm_num_groups=unet.config.norm_num_groups, + norm_eps=unet.config.norm_eps, + cross_attention_dim=unet.config.cross_attention_dim, + transformer_layers_per_block=transformer_layers_per_block, + encoder_hid_dim=encoder_hid_dim, + encoder_hid_dim_type=encoder_hid_dim_type, + attention_head_dim=unet.config.attention_head_dim, + num_attention_heads=unet.config.num_attention_heads, + use_linear_projection=unet.config.use_linear_projection, + class_embed_type=unet.config.class_embed_type, + addition_embed_type=addition_embed_type, + addition_time_embed_dim=addition_time_embed_dim, + num_class_embeds=unet.config.num_class_embeds, + upcast_attention=unet.config.upcast_attention, + resnet_time_scale_shift=unet.config.resnet_time_scale_shift, + projection_class_embeddings_input_dim=unet.config.projection_class_embeddings_input_dim, + brushnet_conditioning_channel_order=brushnet_conditioning_channel_order, + conditioning_embedding_out_channels=conditioning_embedding_out_channels, + ) + + if load_weights_from_unet: + conv_in_condition_weight=torch.zeros_like(brushnet.conv_in_condition.weight) + conv_in_condition_weight[:,:4,...]=unet.conv_in.weight + conv_in_condition_weight[:,4:8,...]=unet.conv_in.weight + brushnet.conv_in_condition.weight=torch.nn.Parameter(conv_in_condition_weight) + brushnet.conv_in_condition.bias=unet.conv_in.bias + + brushnet.time_proj.load_state_dict(unet.time_proj.state_dict()) + brushnet.time_embedding.load_state_dict(unet.time_embedding.state_dict()) + + if brushnet.class_embedding: + brushnet.class_embedding.load_state_dict(unet.class_embedding.state_dict()) + + brushnet.down_blocks.load_state_dict(unet.down_blocks.state_dict(),strict=False) + brushnet.mid_block.load_state_dict(unet.mid_block.state_dict(),strict=False) + brushnet.up_blocks.load_state_dict(unet.up_blocks.state_dict(),strict=False) + + return brushnet + + @property + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: 
+ r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. + """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attention_slice + def set_attention_slice(self, slice_size: Union[str, int, List[int]]) -> None: + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. 
If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value: bool = False) -> None: + if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + brushnet_cond: torch.FloatTensor, + conditioning_scale: float = 1.0, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guess_mode: bool = False, + return_dict: bool = True, + ) -> Union[BrushNetOutput, Tuple[Tuple[torch.FloatTensor, ...], torch.FloatTensor]]: + """ + The [`BrushNetModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor. + timestep (`Union[torch.Tensor, float, int]`): + The number of timesteps to denoise an input. + encoder_hidden_states (`torch.Tensor`): + The encoder hidden states. + brushnet_cond (`torch.FloatTensor`): + The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`. + conditioning_scale (`float`, defaults to `1.0`): + The scale factor for BrushNet outputs. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. 
Their embeddings will be summed with the timestep embeddings. + timestep_cond (`torch.Tensor`, *optional*, defaults to `None`): + Additional conditional embeddings for timestep. If provided, the embeddings will be summed with the + timestep_embedding passed through the `self.time_embedding` layer to obtain the final timestep + embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + added_cond_kwargs (`dict`): + Additional conditions for the Stable Diffusion XL UNet. + cross_attention_kwargs (`dict[str]`, *optional*, defaults to `None`): + A kwargs dictionary that if specified is passed along to the `AttnProcessor`. + guess_mode (`bool`, defaults to `False`): + In this mode, the BrushNet encoder tries its best to recognize the input content of the input even if + you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended. + return_dict (`bool`, defaults to `True`): + Whether or not to return a [`~models.brushnet.BrushNetOutput`] instead of a plain tuple. + + Returns: + [`~models.brushnet.BrushNetOutput`] **or** `tuple`: + If `return_dict` is `True`, a [`~models.brushnet.BrushNetOutput`] is returned, otherwise a tuple is + returned where the first element is the sample tensor. + """ + # check channel order + channel_order = self.config.brushnet_conditioning_channel_order + + if channel_order == "rgb": + # in rgb order by default + ... + elif channel_order == "bgr": + brushnet_cond = torch.flip(brushnet_cond, dims=[1]) + else: + raise ValueError(f"unknown `brushnet_conditioning_channel_order`: {channel_order}") + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. 
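+        # Shape sketch (illustrative; assumes an SD1.5-sized config with block_out_channels[0] == 320,
+        # so time_embed_dim == 1280):
+        #   timesteps: (batch,)        --time_proj-->       t_emb: (batch, 320), float32
+        #   t_emb.to(sample.dtype)     --time_embedding-->  emb:   (batch, 1280)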
+ t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + if self.config.addition_embed_type is not None: + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + + elif self.config.addition_embed_type == "text_time": + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + + emb = emb + aug_emb if aug_emb is not None else emb + + # 2. pre-process + brushnet_cond=torch.concat([sample,brushnet_cond],1) + sample = self.conv_in_condition(brushnet_cond) + + + # 3. down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 4. PaintingNet down blocks + brushnet_down_block_res_samples = () + for down_block_res_sample, brushnet_down_block in zip(down_block_res_samples, self.brushnet_down_blocks): + down_block_res_sample = brushnet_down_block(down_block_res_sample) + brushnet_down_block_res_samples = brushnet_down_block_res_samples + (down_block_res_sample,) + + # 5. mid + if self.mid_block is not None: + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + ) + else: + sample = self.mid_block(sample, emb) + + # 6. BrushNet mid blocks + brushnet_mid_block_res_sample = self.brushnet_mid_block(sample) + + + # 7. 
up + up_block_res_samples = () + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample, up_res_samples = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + return_res_samples=True + ) + else: + sample, up_res_samples = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + return_res_samples=True + ) + + up_block_res_samples += up_res_samples + + # 8. BrushNet up blocks + brushnet_up_block_res_samples = () + for up_block_res_sample, brushnet_up_block in zip(up_block_res_samples, self.brushnet_up_blocks): + up_block_res_sample = brushnet_up_block(up_block_res_sample) + brushnet_up_block_res_samples = brushnet_up_block_res_samples + (up_block_res_sample,) + + # 6. scaling + if guess_mode and not self.config.global_pool_conditions: + scales = torch.logspace(-1, 0, len(brushnet_down_block_res_samples) + 1 + len(brushnet_up_block_res_samples), device=sample.device) # 0.1 to 1.0 + scales = scales * conditioning_scale + + brushnet_down_block_res_samples = [sample * scale for sample, scale in zip(brushnet_down_block_res_samples, scales[:len(brushnet_down_block_res_samples)])] + brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * scales[len(brushnet_down_block_res_samples)] + brushnet_up_block_res_samples = [sample * scale for sample, scale in zip(brushnet_up_block_res_samples, scales[len(brushnet_down_block_res_samples)+1:])] + else: + brushnet_down_block_res_samples = [sample * conditioning_scale for sample in brushnet_down_block_res_samples] + brushnet_mid_block_res_sample = brushnet_mid_block_res_sample * conditioning_scale + brushnet_up_block_res_samples = [sample * conditioning_scale for sample in brushnet_up_block_res_samples] + + + if self.config.global_pool_conditions: + brushnet_down_block_res_samples = [ + torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_down_block_res_samples + ] + brushnet_mid_block_res_sample = torch.mean(brushnet_mid_block_res_sample, dim=(2, 3), keepdim=True) + brushnet_up_block_res_samples = [ + torch.mean(sample, dim=(2, 3), keepdim=True) for sample in brushnet_up_block_res_samples + ] + + if not return_dict: + return (brushnet_down_block_res_samples, brushnet_mid_block_res_sample, brushnet_up_block_res_samples) + + return BrushNetOutput( + down_block_res_samples=brushnet_down_block_res_samples, + mid_block_res_sample=brushnet_mid_block_res_sample, + up_block_res_samples=brushnet_up_block_res_samples + ) + + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module diff --git a/libs/transformer_temporal.py b/libs/transformer_temporal.py new file mode 100644 index 0000000000000000000000000000000000000000..32928afbb87f903e3b5b75507455b5876cc7a7bd --- /dev/null +++ b/libs/transformer_temporal.py @@ -0,0 +1,375 @@ +# Copyright 2023 The HuggingFace Team. 
All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, Optional + +import torch +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput +from diffusers.models.attention import BasicTransformerBlock, TemporalBasicTransformerBlock +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.resnet import AlphaBlender + + +@dataclass +class TransformerTemporalModelOutput(BaseOutput): + """ + The output of [`TransformerTemporalModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size x num_frames, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. + """ + + sample: torch.FloatTensor + + +class TransformerTemporalModel(ModelMixin, ConfigMixin): + """ + A Transformer model for video-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlock` attention should contain a bias parameter. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + activation_fn (`str`, *optional*, defaults to `"geglu"`): + Activation function to use in feed-forward. See `diffusers.models.activations.get_activation` for supported + activation functions. + norm_elementwise_affine (`bool`, *optional*): + Configure if the `TransformerBlock` should use learnable elementwise affine parameters for normalization. + double_self_attention (`bool`, *optional*): + Configure if each `TransformerBlock` should contain two self-attention layers. + positional_embeddings: (`str`, *optional*): + The type of positional embeddings to apply to the sequence input before passing use. + num_positional_embeddings: (`int`, *optional*): + The maximum length of the sequence over which to apply positional embeddings. 
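+
+    Example (an illustrative sketch; the head, channel, and frame counts below are arbitrary
+    assumptions, not values used elsewhere in this repository):
+
+    ```py
+    >>> import torch
+    >>> model = TransformerTemporalModel(num_attention_heads=8, attention_head_dim=40, in_channels=320)
+    >>> frames = torch.randn(2 * 8, 320, 32, 32)  # (batch * num_frames, channels, height, width)
+    >>> out = model(frames, num_frames=8)  # residual output with the same shape as `frames`
+    ```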
+ """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + activation_fn: str = "geglu", + norm_elementwise_affine: bool = True, + double_self_attention: bool = True, + positional_embeddings: Optional[str] = None, + num_positional_embeddings: Optional[int] = None, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Linear(in_channels, inner_dim) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + double_self_attention=double_self_attention, + norm_elementwise_affine=norm_elementwise_affine, + positional_embeddings=positional_embeddings, + num_positional_embeddings=num_positional_embeddings, + ) + for d in range(num_layers) + ] + ) + + self.proj_out = nn.Linear(inner_dim, in_channels) + + def forward( + self, + hidden_states: torch.FloatTensor, + timestep: Optional[torch.LongTensor] = None, + num_frames: int = 1, + encoder_hidden_states: Optional[torch.LongTensor] = None, + class_labels: torch.LongTensor = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> TransformerTemporalModelOutput: + """ + The [`TransformerTemporal`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input hidden_states. + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + num_frames (`int`, *optional*, defaults to 1): + The number of frames to be processed per batch. This is used to reshape the hidden states. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. 
+ + Returns: + [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`: + If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is + returned, otherwise a `tuple` where the first element is the sample tensor. + """ + # 1. Input + batch_frames, channel, height, width = hidden_states.shape + batch_size = batch_frames // num_frames + + residual = hidden_states + + hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, channel, height, width) + hidden_states = hidden_states.permute(0, 2, 1, 3, 4) + + hidden_states = self.norm(hidden_states) + hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel) + + hidden_states = self.proj_in(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + ) + + # 3. Output + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states[None, None, :] + .reshape(batch_size, height, width, num_frames, channel) + .permute(0, 3, 4, 1, 2) + .contiguous() + ) + hidden_states = hidden_states.reshape(batch_frames, channel, height, width) + + output = hidden_states + residual + + return output + + +class TransformerSpatioTemporalModel(nn.Module): + """ + A Transformer model for video-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + out_channels (`int`, *optional*): + The number of channels in the output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + """ + + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: int = 320, + out_channels: Optional[int] = None, + num_layers: int = 1, + cross_attention_dim: Optional[int] = None, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + + inner_dim = num_attention_heads * attention_head_dim + self.inner_dim = inner_dim + + # 2. Define input layers + self.in_channels = in_channels + self.norm = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6) + self.proj_in = nn.Linear(in_channels, inner_dim) + + # 3. 
Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + ) + for d in range(num_layers) + ] + ) + + time_mix_inner_dim = inner_dim + self.temporal_transformer_blocks = nn.ModuleList( + [ + TemporalBasicTransformerBlock( + inner_dim, + time_mix_inner_dim, + num_attention_heads, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + ) + for _ in range(num_layers) + ] + ) + + time_embed_dim = in_channels * 4 + self.time_pos_embed = TimestepEmbedding(in_channels, time_embed_dim, out_dim=in_channels) + self.time_proj = Timesteps(in_channels, True, 0) + self.time_mixer = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images") + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + # TODO: should use out_channels for continuous projections + self.proj_out = nn.Linear(inner_dim, in_channels) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + image_only_indicator: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): + """ + Args: + hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`): + Input hidden_states. + num_frames (`int`): + The number of frames to be processed per batch. This is used to reshape the hidden states. + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + image_only_indicator (`torch.LongTensor` of shape `(batch size, num_frames)`, *optional*): + A tensor indicating whether the input contains only images. 1 indicates that the input contains only + images, 0 indicates that the input contains video frames. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.transformer_temporal.TransformerTemporalModelOutput`] instead of a plain + tuple. + + Returns: + [`~models.transformer_temporal.TransformerTemporalModelOutput`] or `tuple`: + If `return_dict` is True, an [`~models.transformer_temporal.TransformerTemporalModelOutput`] is + returned, otherwise a `tuple` where the first element is the sample tensor. + """ + # 1. 
Input + batch_frames, _, height, width = hidden_states.shape + num_frames = image_only_indicator.shape[-1] + batch_size = batch_frames // num_frames + + time_context = encoder_hidden_states + time_context_first_timestep = time_context[None, :].reshape( + batch_size, num_frames, -1, time_context.shape[-1] + )[:, 0] + time_context = time_context_first_timestep[None, :].broadcast_to( + height * width, batch_size, 1, time_context.shape[-1] + ) + time_context = time_context.reshape(height * width * batch_size, 1, time_context.shape[-1]) + + residual = hidden_states + + hidden_states = self.norm(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch_frames, height * width, inner_dim) + hidden_states = self.proj_in(hidden_states) + + num_frames_emb = torch.arange(num_frames, device=hidden_states.device) + num_frames_emb = num_frames_emb.repeat(batch_size, 1) + num_frames_emb = num_frames_emb.reshape(-1) + t_emb = self.time_proj(num_frames_emb) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=hidden_states.dtype) + + emb = self.time_pos_embed(t_emb) + emb = emb[:, None, :] + + # 2. Blocks + for block, temporal_block in zip(self.transformer_blocks, self.temporal_transformer_blocks): + if self.training and self.gradient_checkpointing: + hidden_states = torch.utils.checkpoint.checkpoint( + block, + hidden_states, + None, + encoder_hidden_states, + None, + use_reentrant=False, + ) + else: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + ) + + hidden_states_mix = hidden_states + hidden_states_mix = hidden_states_mix + emb + + hidden_states_mix = temporal_block( + hidden_states_mix, + num_frames=num_frames, + encoder_hidden_states=time_context, + ) + hidden_states = self.time_mixer( + x_spatial=hidden_states, + x_temporal=hidden_states_mix, + image_only_indicator=image_only_indicator, + ) + + # 3. Output + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch_frames, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + + if not return_dict: + return (output,) + + return TransformerTemporalModelOutput(sample=output) diff --git a/libs/unet_2d_blocks.py b/libs/unet_2d_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..0b8e6f51f3017b0fb8bd3f59cb05eb23c5f09a80 --- /dev/null +++ b/libs/unet_2d_blocks.py @@ -0,0 +1,3824 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
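+
+# The factory helpers defined below (`get_down_block`, `get_mid_block`, `get_up_block`) build the
+# individual UNet stages from their string names. Illustrative call (argument values are
+# assumptions chosen for a small SD-style stage, not taken from this file):
+#
+#     block = get_down_block(
+#         "DownBlock2D", num_layers=2, in_channels=320, out_channels=320, temb_channels=1280,
+#         add_downsample=True, resnet_eps=1e-5, resnet_act_fn="silu",
+#         resnet_groups=32, downsample_padding=1,
+#     )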
+from typing import Any, Dict, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.utils import is_torch_version, logging +from diffusers.utils.torch_utils import apply_freeu +from diffusers.models.activations import get_activation +from diffusers.models.attention_processor import Attention, AttnAddedKVProcessor, AttnAddedKVProcessor2_0 +from diffusers.models.normalization import AdaGroupNorm +from diffusers.models.resnet import ( + Downsample2D, + FirDownsample2D, + FirUpsample2D, + KDownsample2D, + KUpsample2D, + ResnetBlock2D, + ResnetBlockCondNorm2D, + Upsample2D, +) +from diffusers.models.transformers.dual_transformer_2d import DualTransformer2DModel +from diffusers.models.transformers.transformer_2d import Transformer2DModel + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def get_down_block( + down_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + temb_channels: int, + add_downsample: bool, + resnet_eps: float, + resnet_act_fn: str, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = None, + resnet_groups: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + downsample_padding: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + attention_type: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, + cross_attention_norm: Optional[str] = None, + attention_head_dim: Optional[int] = None, + downsample_type: Optional[str] = None, + dropout: float = 0.0, +): + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warn( + f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}." 
+ ) + attention_head_dim = num_attention_heads + + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock2D": + return DownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "ResnetDownsampleBlock2D": + return ResnetDownsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + ) + elif down_block_type == "AttnDownBlock2D": + if add_downsample is False: + downsample_type = None + else: + downsample_type = downsample_type or "conv" # default to 'conv' + return AttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + downsample_type=downsample_type, + ) + elif down_block_type == "CrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock2D") + return CrossAttnDownBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + ) + elif down_block_type == "SimpleCrossAttnDownBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnDownBlock2D") + return SimpleCrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif down_block_type == "SkipDownBlock2D": + return SkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + 
resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "AttnSkipDownBlock2D": + return AttnSkipDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "DownEncoderBlock2D": + return DownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "AttnDownEncoderBlock2D": + return AttnDownEncoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "KDownBlock2D": + return KDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif down_block_type == "KCrossAttnDownBlock2D": + return KCrossAttnDownBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + dropout=dropout, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + add_self_attention=True if not add_downsample else False, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_mid_block( + mid_block_type: str, + temb_channels: int, + in_channels: int, + resnet_eps: float, + resnet_act_fn: str, + resnet_groups: int, + output_scale_factor: float = 1.0, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + mid_block_only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + attention_type: str = "default", + resnet_skip_time_act: bool = False, + cross_attention_norm: Optional[str] = None, + attention_head_dim: Optional[int] = 1, + dropout: float = 0.0, +): + if mid_block_type == "UNetMidBlock2DCrossAttn": + return UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + temb_channels=temb_channels, + dropout=dropout, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + output_scale_factor=output_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + resnet_groups=resnet_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + elif 
mid_block_type == "UNetMidBlock2DSimpleCrossAttn": + return UNetMidBlock2DSimpleCrossAttn( + in_channels=in_channels, + temb_channels=temb_channels, + dropout=dropout, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + output_scale_factor=output_scale_factor, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif mid_block_type == "UNetMidBlock2D": + return UNetMidBlock2D( + in_channels=in_channels, + temb_channels=temb_channels, + dropout=dropout, + num_layers=0, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + output_scale_factor=output_scale_factor, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + add_attention=False, + ) + elif mid_block_type == "MidBlock2D": + return MidBlock2D( + in_channels=in_channels, + temb_channels=temb_channels, + dropout=dropout, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + output_scale_factor=output_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_groups=resnet_groups, + use_linear_projection=use_linear_projection, + ) + elif mid_block_type == "MidBlock2D": + return MidBlock2D( + in_channels=in_channels, + temb_channels=temb_channels, + dropout=dropout, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + output_scale_factor=output_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + resnet_groups=resnet_groups, + use_linear_projection=use_linear_projection, + ) + elif mid_block_type is None: + return None + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + +def get_up_block( + up_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + add_upsample: bool, + resnet_eps: float, + resnet_act_fn: str, + resolution_idx: Optional[int] = None, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = None, + resnet_groups: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + attention_type: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, + cross_attention_norm: Optional[str] = None, + attention_head_dim: Optional[int] = None, + upsample_type: Optional[str] = None, + dropout: float = 0.0, +) -> nn.Module: + # If attn head dim is not defined, we default it to the number of heads + if attention_head_dim is None: + logger.warn( + f"It is recommended to provide `attention_head_dim` when calling `get_up_block`. Defaulting `attention_head_dim` to {num_attention_heads}." 
+ ) + attention_head_dim = num_attention_heads + + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock2D": + return UpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "ResnetUpsampleBlock2D": + return ResnetUpsampleBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + ) + elif up_block_type == "CrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock2D") + return CrossAttnUpBlock2D( + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + ) + elif up_block_type == "SimpleCrossAttnUpBlock2D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for SimpleCrossAttnUpBlock2D") + return SimpleCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + output_scale_factor=resnet_out_scale_factor, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif up_block_type == "AttnUpBlock2D": + if add_upsample is False: + upsample_type = None + else: + upsample_type = upsample_type or "conv" # default to 'conv' + + return AttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + upsample_type=upsample_type, + ) + elif up_block_type == "SkipUpBlock2D": + return 
SkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "AttnSkipUpBlock2D": + return AttnSkipUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "UpDecoderBlock2D": + return UpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + temb_channels=temb_channels, + ) + elif up_block_type == "AttnUpDecoderBlock2D": + return AttnUpDecoderBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + attention_head_dim=attention_head_dim, + resnet_time_scale_shift=resnet_time_scale_shift, + temb_channels=temb_channels, + ) + elif up_block_type == "KUpBlock2D": + return KUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + ) + elif up_block_type == "KCrossAttnUpBlock2D": + return KCrossAttnUpBlock2D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + dropout=dropout, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + ) + + raise ValueError(f"{up_block_type} does not exist.") + + +class AutoencoderTinyBlock(nn.Module): + """ + Tiny Autoencoder block used in [`AutoencoderTiny`]. It is a mini residual module consisting of plain conv + ReLU + blocks. + + Args: + in_channels (`int`): The number of input channels. + out_channels (`int`): The number of output channels. + act_fn (`str`): + ` The activation function to use. Supported values are `"swish"`, `"mish"`, `"gelu"`, and `"relu"`. + + Returns: + `torch.FloatTensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to + `out_channels`. 
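+
+    Example (an illustrative sketch; the channel sizes are arbitrary assumptions):
+
+    ```py
+    >>> import torch
+    >>> block = AutoencoderTinyBlock(in_channels=64, out_channels=64, act_fn="relu")
+    >>> block(torch.randn(1, 64, 32, 32)).shape
+    torch.Size([1, 64, 32, 32])
+    ```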
+ """ + + def __init__(self, in_channels: int, out_channels: int, act_fn: str): + super().__init__() + act_fn = get_activation(act_fn) + self.conv = nn.Sequential( + nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), + act_fn, + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), + act_fn, + nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1), + ) + self.skip = ( + nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False) + if in_channels != out_channels + else nn.Identity() + ) + self.fuse = nn.ReLU() + + def forward(self, x: torch.FloatTensor) -> torch.FloatTensor: + return self.fuse(self.conv(x) + self.skip(x)) + + +class UNetMidBlock2D(nn.Module): + """ + A 2D UNet mid-block [`UNetMidBlock2D`] with multiple residual blocks and optional attention blocks. + + Args: + in_channels (`int`): The number of input channels. + temb_channels (`int`): The number of temporal embedding channels. + dropout (`float`, *optional*, defaults to 0.0): The dropout rate. + num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. + resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. + resnet_time_scale_shift (`str`, *optional*, defaults to `default`): + The type of normalization to apply to the time embeddings. This can help to improve the performance of the + model on tasks with long-range temporal dependencies. + resnet_act_fn (`str`, *optional*, defaults to `swish`): The activation function for the resnet blocks. + resnet_groups (`int`, *optional*, defaults to 32): + The number of groups to use in the group normalization layers of the resnet blocks. + attn_groups (`Optional[int]`, *optional*, defaults to None): The number of groups for the attention blocks. + resnet_pre_norm (`bool`, *optional*, defaults to `True`): + Whether to use pre-normalization for the resnet blocks. + add_attention (`bool`, *optional*, defaults to `True`): Whether to add attention blocks. + attention_head_dim (`int`, *optional*, defaults to 1): + Dimension of a single attention head. The number of attention heads is determined based on this value and + the number of input channels. + output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor. + + Returns: + `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size, + in_channels, height, width)`. 
+ + """ + + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + attn_groups: Optional[int] = None, + resnet_pre_norm: bool = True, + add_attention: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + ): + super().__init__() + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + self.add_attention = add_attention + + if attn_groups is None: + attn_groups = resnet_groups if resnet_time_scale_shift == "default" else None + + # there is always at least one resnet + if resnet_time_scale_shift == "spatial": + resnets = [ + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ] + else: + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + if attention_head_dim is None: + logger.warn( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}." + ) + attention_head_dim = in_channels + + for _ in range(num_layers): + if self.add_attention: + attentions.append( + Attention( + in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=attn_groups, + spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + else: + attentions.append(None) + + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor: + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if attn is not None: + hidden_states = attn(hidden_states, temb=temb) + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class UNetMidBlock2DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + 
resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for i in range(num_layers): + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + 
return_dict=False, + )[0] + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + + return hidden_states + + +class UNetMidBlock2DSimpleCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + skip_time_act: bool = False, + only_cross_attention: bool = False, + cross_attention_norm: Optional[str] = None, + ): + super().__init__() + + self.has_cross_attention = True + + self.attention_head_dim = attention_head_dim + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + self.num_heads = in_channels // self.attention_head_dim + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ] + attentions = [] + + for _ in range(num_layers): + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) + + attentions.append( + Attention( + query_dim=in_channels, + cross_attention_dim=in_channels, + heads=self.num_heads, + dim_head=self.attention_head_dim, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + lora_scale = cross_attention_kwargs.get("scale", 1.0) + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. 
+ # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + + hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + # attn + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + + # resnet + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + + return hidden_states + + +class MidBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + use_linear_projection: bool = False, + ): + super().__init__() + + self.has_cross_attention = False + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + + for i in range(num_layers): + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + lora_scale = 1.0 + hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) + for resnet in self.resnets[1:]: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + else: + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + + return hidden_states + + +class AttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + downsample_type: str = "conv", + ): + super().__init__() + resnets = [] + attentions = [] + 
self.downsample_type = downsample_type + + if attention_head_dim is None: + logger.warn( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." + ) + attention_head_dim = out_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if downsample_type == "conv": + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + elif downsample_type == "resnet": + self.downsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + down=True, + ) + ] + ) + else: + self.downsamplers = None + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + lora_scale = cross_attention_kwargs.get("scale", 1.0) + + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + cross_attention_kwargs.update({"scale": lora_scale}) + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = attn(hidden_states, **cross_attention_kwargs) + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + if self.downsample_type == "resnet": + hidden_states = downsampler(hidden_states, temb=temb, scale=lora_scale) + else: + hidden_states = downsampler(hidden_states, scale=lora_scale) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + add_downsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + 
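# attention_type is passed straight through to each Transformer2DModel created below +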
attention_type: str = "default", + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + additional_residuals: Optional[torch.FloatTensor] = None, + down_block_add_samples: Optional[torch.FloatTensor] = None, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + output_states = () + + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + blocks = list(zip(self.resnets, self.attentions)) + + for i, (resnet, attn) in enumerate(blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + 
encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + + # apply additional residuals to the output of the last pair of resnet and attention blocks + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals + + if down_block_add_samples is not None: + hidden_states = hidden_states + down_block_add_samples.pop(0) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, scale=lora_scale) + + if down_block_add_samples is not None: + hidden_states = hidden_states + down_block_add_samples.pop(0) # todo: add before or after + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class DownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0, + down_block_add_samples: Optional[torch.FloatTensor] = None, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb, scale=scale) + + if down_block_add_samples is not None: + hidden_states = hidden_states + down_block_add_samples.pop(0) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, scale=scale) + + if down_block_add_samples is not None: + hidden_states = hidden_states + down_block_add_samples.pop(0) # todo: add before or after + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class DownEncoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 
1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor: + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb=None, scale=scale) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, scale) + + return hidden_states + + +class AttnDownEncoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + resnets = [] + attentions = [] + + if attention_head_dim is None: + logger.warn( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." 
+ ) + attention_head_dim = out_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=None, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor: + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb=None, scale=scale) + cross_attention_kwargs = {"scale": scale} + hidden_states = attn(hidden_states, **cross_attention_kwargs) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, scale) + + return hidden_states + + +class AttnSkipDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = np.sqrt(2.0), + add_downsample: bool = True, + ): + super().__init__() + self.attentions = nn.ModuleList([]) + self.resnets = nn.ModuleList([]) + + if attention_head_dim is None: + logger.warn( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." 
+ ) + attention_head_dim = out_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + self.resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(in_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + self.attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=32, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + if add_downsample: + self.resnet_down = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + down=True, + kernel="fir", + ) + self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)]) + self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) + else: + self.resnet_down = None + self.downsamplers = None + self.skip_conv = None + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + skip_sample: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]: + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb, scale=scale) + cross_attention_kwargs = {"scale": scale} + hidden_states = attn(hidden_states, **cross_attention_kwargs) + output_states += (hidden_states,) + + if self.downsamplers is not None: + hidden_states = self.resnet_down(hidden_states, temb, scale=scale) + for downsampler in self.downsamplers: + skip_sample = downsampler(skip_sample) + + hidden_states = self.skip_conv(skip_sample) + hidden_states + + output_states += (hidden_states,) + + return hidden_states, output_states, skip_sample + + +class SkipDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + output_scale_factor: float = np.sqrt(2.0), + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + self.resnets = nn.ModuleList([]) + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + self.resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(in_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + if add_downsample: + self.resnet_down = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, 
+ groups=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + down=True, + kernel="fir", + ) + self.downsamplers = nn.ModuleList([FirDownsample2D(out_channels, out_channels=out_channels)]) + self.skip_conv = nn.Conv2d(3, out_channels, kernel_size=(1, 1), stride=(1, 1)) + else: + self.resnet_down = None + self.downsamplers = None + self.skip_conv = None + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + skip_sample: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]: + output_states = () + + for resnet in self.resnets: + hidden_states = resnet(hidden_states, temb, scale) + output_states += (hidden_states,) + + if self.downsamplers is not None: + hidden_states = self.resnet_down(hidden_states, temb, scale) + for downsampler in self.downsamplers: + skip_sample = downsampler(skip_sample) + + hidden_states = self.skip_conv(skip_sample) + hidden_states + + output_states += (hidden_states,) + + return hidden_states, output_states, skip_sample + + +class ResnetDownsampleBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + skip_time_act: bool = False, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + down=True, + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0 + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb, scale) + + output_states = output_states + 
(hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, temb, scale) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class SimpleCrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + skip_time_act: bool = False, + only_cross_attention: bool = False, + cross_attention_norm: Optional[str] = None, + ): + super().__init__() + + self.has_cross_attention = True + + resnets = [] + attentions = [] + + self.attention_head_dim = attention_head_dim + self.num_heads = out_channels // self.attention_head_dim + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) + + attentions.append( + Attention( + query_dim=out_channels, + cross_attention_dim=out_channels, + heads=self.num_heads, + dim_head=attention_head_dim, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + down=True, + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + output_states = () + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + lora_scale = cross_attention_kwargs.get("scale", 1.0) + + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. 
+ # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. + # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + else: + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, temb, scale=lora_scale) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class KDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 4, + resnet_eps: float = 1e-5, + resnet_act_fn: str = "gelu", + resnet_group_size: int = 32, + add_downsample: bool = False, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=groups, + groups_out=groups_out, + eps=resnet_eps, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + # YiYi's comments- might be able to use FirDownsample2D, look into details later + self.downsamplers = nn.ModuleList([KDownsample2D()]) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0 + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb, scale) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states, output_states + + 
+class KCrossAttnDownBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + cross_attention_dim: int, + dropout: float = 0.0, + num_layers: int = 4, + resnet_group_size: int = 32, + add_downsample: bool = True, + attention_head_dim: int = 64, + add_self_attention: bool = False, + resnet_eps: float = 1e-5, + resnet_act_fn: str = "gelu", + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=out_channels, + dropout=dropout, + temb_channels=temb_channels, + groups=groups, + groups_out=groups_out, + eps=resnet_eps, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + attentions.append( + KAttentionBlock( + out_channels, + out_channels // attention_head_dim, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + temb_channels=temb_channels, + attention_bias=True, + add_self_attention=add_self_attention, + cross_attention_norm="layer_norm", + group_size=resnet_group_size, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + + if add_downsample: + self.downsamplers = nn.ModuleList([KDownsample2D()]) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + output_states = () + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=temb, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=temb, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + + if self.downsamplers is None: + output_states += (None,) + else: + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + return hidden_states, output_states + + +class AttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + 
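# prev_output_channel: width of the features produced by the preceding up block; skip connections from the down path are concatenated channel-wise in forward +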
prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: int = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + upsample_type: str = "conv", + ): + super().__init__() + resnets = [] + attentions = [] + + self.upsample_type = upsample_type + + if attention_head_dim is None: + logger.warn( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {out_channels}." + ) + attention_head_dim = out_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if upsample_type == "conv": + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + elif upsample_type == "resnet": + self.upsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + up=True, + ) + ] + ) + else: + self.upsamplers = None + + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + scale: float = 1.0, + ) -> torch.FloatTensor: + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb, scale=scale) + cross_attention_kwargs = {"scale": scale} + hidden_states = attn(hidden_states, **cross_attention_kwargs) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + if self.upsample_type == "resnet": + hidden_states = upsampler(hidden_states, temb=temb, scale=scale) + else: + hidden_states = upsampler(hidden_states, scale=scale) + + return hidden_states + + +class CrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 
1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + return_res_samples: Optional[bool]=False, + up_block_add_samples: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + if return_res_samples: + output_states=() + + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + 
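# apply_freeu re-weights the backbone features and filters the skip features (via b1/b2 and s1/s2) before they are concatenated below +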
res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + else: + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + if return_res_samples: + output_states = output_states + (hidden_states,) + if up_block_add_samples is not None: + hidden_states = hidden_states + up_block_add_samples.pop(0) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale) + if return_res_samples: + output_states = output_states + (hidden_states,) + if up_block_add_samples is not None: + hidden_states = hidden_states + up_block_add_samples.pop(0) + + if return_res_samples: + return hidden_states, output_states + else: + return hidden_states + +class UpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + scale: float = 1.0, + return_res_samples: Optional[bool]=False, + up_block_add_samples: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + is_freeu_enabled = ( + getattr(self, 
"s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + if return_res_samples: + output_states = () + + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb, scale=scale) + + if return_res_samples: + output_states = output_states + (hidden_states,) + if up_block_add_samples is not None: + hidden_states = hidden_states + up_block_add_samples.pop(0) # todo: add before or after + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size, scale=scale) + + if return_res_samples: + output_states = output_states + (hidden_states,) + if up_block_add_samples is not None: + hidden_states = hidden_states + up_block_add_samples.pop(0) # todo: add before or after + + + if return_res_samples: + return hidden_states, output_states + else: + return hidden_states + + +class UpDecoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", # default, spatial + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + temb_channels: Optional[int] = None, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.resolution_idx = resolution_idx + + def forward( + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0 + ) -> torch.FloatTensor: + for resnet in 
self.resnets: + hidden_states = resnet(hidden_states, temb=temb, scale=scale) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class AttnUpDecoderBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + temb_channels: Optional[int] = None, + ): + super().__init__() + resnets = [] + attentions = [] + + if attention_head_dim is None: + logger.warn( + f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}." + ) + attention_head_dim = out_channels + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + if resnet_time_scale_shift == "spatial": + resnets.append( + ResnetBlockCondNorm2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm="spatial", + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + ) + ) + else: + resnets.append( + ResnetBlock2D( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + attentions.append( + Attention( + out_channels, + heads=out_channels // attention_head_dim, + dim_head=attention_head_dim, + rescale_output_factor=output_scale_factor, + eps=resnet_eps, + norm_num_groups=resnet_groups if resnet_time_scale_shift != "spatial" else None, + spatial_norm_dim=temb_channels if resnet_time_scale_shift == "spatial" else None, + residual_connection=True, + bias=True, + upcast_softmax=True, + _from_deprecated_attn_block=True, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.resolution_idx = resolution_idx + + def forward( + self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0 + ) -> torch.FloatTensor: + for resnet, attn in zip(self.resnets, self.attentions): + hidden_states = resnet(hidden_states, temb=temb, scale=scale) + cross_attention_kwargs = {"scale": scale} + hidden_states = attn(hidden_states, temb=temb, **cross_attention_kwargs) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, scale=scale) + + return hidden_states + + +class AttnSkipUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + output_scale_factor: float = np.sqrt(2.0), + add_upsample: bool = 
True,
+    ):
+        super().__init__()
+        self.attentions = nn.ModuleList([])
+        self.resnets = nn.ModuleList([])
+
+        for i in range(num_layers):
+            res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+            resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+            self.resnets.append(
+                ResnetBlock2D(
+                    in_channels=resnet_in_channels + res_skip_channels,
+                    out_channels=out_channels,
+                    temb_channels=temb_channels,
+                    eps=resnet_eps,
+                    groups=min((resnet_in_channels + res_skip_channels) // 4, 32),
+                    groups_out=min(out_channels // 4, 32),
+                    dropout=dropout,
+                    time_embedding_norm=resnet_time_scale_shift,
+                    non_linearity=resnet_act_fn,
+                    output_scale_factor=output_scale_factor,
+                    pre_norm=resnet_pre_norm,
+                )
+            )
+
+        if attention_head_dim is None:
+            logger.warning(
+                f"It is not recommended to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `out_channels`: {out_channels}."
+            )
+            attention_head_dim = out_channels
+
+        self.attentions.append(
+            Attention(
+                out_channels,
+                heads=out_channels // attention_head_dim,
+                dim_head=attention_head_dim,
+                rescale_output_factor=output_scale_factor,
+                eps=resnet_eps,
+                norm_num_groups=32,
+                residual_connection=True,
+                bias=True,
+                upcast_softmax=True,
+                _from_deprecated_attn_block=True,
+            )
+        )
+
+        self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels)
+        if add_upsample:
+            self.resnet_up = ResnetBlock2D(
+                in_channels=out_channels,
+                out_channels=out_channels,
+                temb_channels=temb_channels,
+                eps=resnet_eps,
+                groups=min(out_channels // 4, 32),
+                groups_out=min(out_channels // 4, 32),
+                dropout=dropout,
+                time_embedding_norm=resnet_time_scale_shift,
+                non_linearity=resnet_act_fn,
+                output_scale_factor=output_scale_factor,
+                pre_norm=resnet_pre_norm,
+                use_in_shortcut=True,
+                up=True,
+                kernel="fir",
+            )
+            self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
+            self.skip_norm = torch.nn.GroupNorm(
+                num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True
+            )
+            self.act = nn.SiLU()
+        else:
+            self.resnet_up = None
+            self.skip_conv = None
+            self.skip_norm = None
+            self.act = None
+
+        self.resolution_idx = resolution_idx
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
+        temb: Optional[torch.FloatTensor] = None,
+        skip_sample=None,
+        scale: float = 1.0,
+    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
+        for resnet in self.resnets:
+            # pop res hidden states
+            res_hidden_states = res_hidden_states_tuple[-1]
+            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+            hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+            hidden_states = resnet(hidden_states, temb, scale=scale)
+
+        cross_attention_kwargs = {"scale": scale}
+        hidden_states = self.attentions[0](hidden_states, **cross_attention_kwargs)
+
+        if skip_sample is not None:
+            skip_sample = self.upsampler(skip_sample)
+        else:
+            skip_sample = 0
+
+        if self.resnet_up is not None:
+            skip_sample_states = self.skip_norm(hidden_states)
+            skip_sample_states = self.act(skip_sample_states)
+            skip_sample_states = self.skip_conv(skip_sample_states)
+
+            skip_sample = skip_sample + skip_sample_states
+
+            hidden_states = self.resnet_up(hidden_states, temb, scale=scale)
+
+        return hidden_states, skip_sample
+
+
+class SkipUpBlock2D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        prev_output_channel: int,
+        out_channels: int,
+        temb_channels: int,
+        resolution_idx: Optional[int] = None,
+        dropout:
float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_pre_norm: bool = True, + output_scale_factor: float = np.sqrt(2.0), + add_upsample: bool = True, + upsample_padding: int = 1, + ): + super().__init__() + self.resnets = nn.ModuleList([]) + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + self.resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min((resnet_in_channels + res_skip_channels) // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.upsampler = FirUpsample2D(in_channels, out_channels=out_channels) + if add_upsample: + self.resnet_up = ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=min(out_channels // 4, 32), + groups_out=min(out_channels // 4, 32), + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + use_in_shortcut=True, + up=True, + kernel="fir", + ) + self.skip_conv = nn.Conv2d(out_channels, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) + self.skip_norm = torch.nn.GroupNorm( + num_groups=min(out_channels // 4, 32), num_channels=out_channels, eps=resnet_eps, affine=True + ) + self.act = nn.SiLU() + else: + self.resnet_up = None + self.skip_conv = None + self.skip_norm = None + self.act = None + + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + skip_sample=None, + scale: float = 1.0, + ) -> Tuple[torch.FloatTensor, torch.FloatTensor]: + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb, scale=scale) + + if skip_sample is not None: + skip_sample = self.upsampler(skip_sample) + else: + skip_sample = 0 + + if self.resnet_up is not None: + skip_sample_states = self.skip_norm(hidden_states) + skip_sample_states = self.act(skip_sample_states) + skip_sample_states = self.skip_conv(skip_sample_states) + + skip_sample = skip_sample + skip_sample_states + + hidden_states = self.resnet_up(hidden_states, temb, scale=scale) + + return hidden_states, skip_sample + + +class ResnetUpsampleBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + skip_time_act: bool = False, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) 
else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + up=True, + ) + ] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + scale: float = 1.0, + ) -> torch.FloatTensor: + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb, scale=scale) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, temb, scale=scale) + + return hidden_states + + +class SimpleCrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attention_head_dim: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + skip_time_act: bool = False, + only_cross_attention: bool = False, + cross_attention_norm: Optional[str] = None, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.attention_head_dim = attention_head_dim + + self.num_heads = out_channels // self.attention_head_dim + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + 
non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + ) + ) + + processor = ( + AttnAddedKVProcessor2_0() if hasattr(F, "scaled_dot_product_attention") else AttnAddedKVProcessor() + ) + + attentions.append( + Attention( + query_dim=out_channels, + cross_attention_dim=out_channels, + heads=self.num_heads, + dim_head=self.attention_head_dim, + added_kv_proj_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + bias=True, + upcast_softmax=True, + only_cross_attention=only_cross_attention, + cross_attention_norm=cross_attention_norm, + processor=processor, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList( + [ + ResnetBlock2D( + in_channels=out_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + skip_time_act=skip_time_act, + up=True, + ) + ] + ) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + lora_scale = cross_attention_kwargs.get("scale", 1.0) + if attention_mask is None: + # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask. + mask = None if encoder_hidden_states is None else encoder_attention_mask + else: + # when attention_mask is defined: we don't even check for encoder_attention_mask. + # this is to maintain compatibility with UnCLIP, which uses 'attention_mask' param for cross-attn masks. + # TODO: UnCLIP should express cross-attn mask via encoder_attention_mask param instead of via attention_mask. 
+ # then we can simplify this whole if/else block to: + # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask + mask = attention_mask + + for resnet, attn in zip(self.resnets, self.attentions): + # resnet + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + else: + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=mask, + **cross_attention_kwargs, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, temb, scale=lora_scale) + + return hidden_states + + +class KUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + resolution_idx: int, + dropout: float = 0.0, + num_layers: int = 5, + resnet_eps: float = 1e-5, + resnet_act_fn: str = "gelu", + resnet_group_size: Optional[int] = 32, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + k_in_channels = 2 * out_channels + k_out_channels = in_channels + num_layers = num_layers - 1 + + for i in range(num_layers): + in_channels = k_in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=k_out_channels if (i == num_layers - 1) else out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=groups, + groups_out=groups_out, + dropout=dropout, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([KUpsample2D()]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + scale: float = 1.0, + ) -> torch.FloatTensor: + res_hidden_states_tuple = res_hidden_states_tuple[-1] + if res_hidden_states_tuple is not None: + hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb, use_reentrant=False + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + else: + hidden_states = resnet(hidden_states, temb, 
scale=scale) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class KCrossAttnUpBlock2D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + resolution_idx: int, + dropout: float = 0.0, + num_layers: int = 4, + resnet_eps: float = 1e-5, + resnet_act_fn: str = "gelu", + resnet_group_size: int = 32, + attention_head_dim: int = 1, # attention dim_head + cross_attention_dim: int = 768, + add_upsample: bool = True, + upcast_attention: bool = False, + ): + super().__init__() + resnets = [] + attentions = [] + + is_first_block = in_channels == out_channels == temb_channels + is_middle_block = in_channels != out_channels + add_self_attention = True if is_first_block else False + + self.has_cross_attention = True + self.attention_head_dim = attention_head_dim + + # in_channels, and out_channels for the block (k-unet) + k_in_channels = out_channels if is_first_block else 2 * out_channels + k_out_channels = in_channels + + num_layers = num_layers - 1 + + for i in range(num_layers): + in_channels = k_in_channels if i == 0 else out_channels + groups = in_channels // resnet_group_size + groups_out = out_channels // resnet_group_size + + if is_middle_block and (i == num_layers - 1): + conv_2d_out_channels = k_out_channels + else: + conv_2d_out_channels = None + + resnets.append( + ResnetBlockCondNorm2D( + in_channels=in_channels, + out_channels=out_channels, + conv_2d_out_channels=conv_2d_out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=groups, + groups_out=groups_out, + dropout=dropout, + non_linearity=resnet_act_fn, + time_embedding_norm="ada_group", + conv_shortcut_bias=False, + ) + ) + attentions.append( + KAttentionBlock( + k_out_channels if (i == num_layers - 1) else out_channels, + k_out_channels // attention_head_dim + if (i == num_layers - 1) + else out_channels // attention_head_dim, + attention_head_dim, + cross_attention_dim=cross_attention_dim, + temb_channels=temb_channels, + attention_bias=True, + add_self_attention=add_self_attention, + cross_attention_norm="layer_norm", + upcast_attention=upcast_attention, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.attentions = nn.ModuleList(attentions) + + if add_upsample: + self.upsamplers = nn.ModuleList([KUpsample2D()]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + res_hidden_states_tuple = res_hidden_states_tuple[-1] + if res_hidden_states_tuple is not None: + hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1) + + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + 
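+                    # NOTE: returning a plain callable lets torch.utils.checkpoint re-run the wrapped
+                    # module during the backward pass instead of caching its intermediate activations.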
return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=temb, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + emb=temb, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +# can potentially later be renamed to `No-feed-forward` attention +class KAttentionBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + attention_bias (`bool`, *optional*, defaults to `False`): + Configure if the attention layers should contain a bias parameter. + upcast_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to upcast the attention computation to `float32`. + temb_channels (`int`, *optional*, defaults to 768): + The number of channels in the token embedding. + add_self_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to add self-attention to the block. + cross_attention_norm (`str`, *optional*, defaults to `None`): + The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. + group_size (`int`, *optional*, defaults to 32): + The number of groups to separate the channels into for group normalization. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout: float = 0.0, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + upcast_attention: bool = False, + temb_channels: int = 768, # for ada_group_norm + add_self_attention: bool = False, + cross_attention_norm: Optional[str] = None, + group_size: int = 32, + ): + super().__init__() + self.add_self_attention = add_self_attention + + # 1. Self-Attn + if add_self_attention: + self.norm1 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=None, + cross_attention_norm=None, + ) + + # 2. 
Cross-Attn + self.norm2 = AdaGroupNorm(temb_channels, dim, max(1, dim // group_size)) + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + cross_attention_norm=cross_attention_norm, + ) + + def _to_3d(self, hidden_states: torch.FloatTensor, height: int, weight: int) -> torch.FloatTensor: + return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * weight, -1) + + def _to_4d(self, hidden_states: torch.FloatTensor, height: int, weight: int) -> torch.FloatTensor: + return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, weight) + + def forward( + self, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + # TODO: mark emb as non-optional (self.norm2 requires it). + # requires assessing impact of change to positional param interface. + emb: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {} + + # 1. Self-Attention + if self.add_self_attention: + norm_hidden_states = self.norm1(hidden_states, emb) + + height, weight = norm_hidden_states.shape[2:] + norm_hidden_states = self._to_3d(norm_hidden_states, height, weight) + + attn_output = self.attn1( + norm_hidden_states, + encoder_hidden_states=None, + attention_mask=attention_mask, + **cross_attention_kwargs, + ) + attn_output = self._to_4d(attn_output, height, weight) + + hidden_states = attn_output + hidden_states + + # 2. Cross-Attention/None + norm_hidden_states = self.norm2(hidden_states, emb) + + height, weight = norm_hidden_states.shape[2:] + norm_hidden_states = self._to_3d(norm_hidden_states, height, weight) + attn_output = self.attn2( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask if encoder_hidden_states is None else encoder_attention_mask, + **cross_attention_kwargs, + ) + attn_output = self._to_4d(attn_output, height, weight) + + hidden_states = attn_output + hidden_states + + return hidden_states diff --git a/libs/unet_2d_condition.py b/libs/unet_2d_condition.py new file mode 100644 index 0000000000000000000000000000000000000000..a8dffb7f9c945e8f3ebd77e2b91eab3b85e08dd0 --- /dev/null +++ b/libs/unet_2d_condition.py @@ -0,0 +1,1359 @@ +# Copyright 2024 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
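+# ------------------------------------------------------------------------------------------------
+# Usage sketch (illustrative only; not part of this module's API). This vendored
+# `UNet2DConditionModel` mirrors `diffusers.UNet2DConditionModel`, so it can typically be loaded
+# from an existing Stable-Diffusion-style checkpoint and called the same way. The checkpoint id,
+# the `libs` package path, and the tensor names below are assumptions made for the example, not
+# values required by this file.
+#
+#     from libs.unet_2d_condition import UNet2DConditionModel
+#
+#     unet = UNet2DConditionModel.from_pretrained(
+#         "runwayml/stable-diffusion-v1-5", subfolder="unet"
+#     )
+#     # latents: (batch, 4, H/8, W/8); text_embeds: (batch, seq_len, cross_attention_dim)
+#     out = unet(latents, timestep, encoder_hidden_states=text_embeds)
+#     noise_pred = out.sample  # `UNet2DConditionOutput.sample`
+# ------------------------------------------------------------------------------------------------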
+from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin +from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers +from diffusers.models.activations import get_activation +from diffusers.models.attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + Attention, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +from diffusers.models.embeddings import ( + GaussianFourierProjection, + GLIGENTextBoundingboxProjection, + ImageHintTimeEmbedding, + ImageProjection, + ImageTimeEmbedding, + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, + TimestepEmbedding, + Timesteps, +) +from diffusers.models.modeling_utils import ModelMixin +from .unet_2d_blocks import ( + get_down_block, + get_mid_block, + get_up_block, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet2DConditionOutput(BaseOutput): + """ + The output of [`UNet2DConditionModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.FloatTensor = None + + +class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin): + r""" + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or + `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. + only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. 
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling + blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): + The number of attention heads. If not defined, defaults to `attention_head_dim` + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + addition_time_embed_dim: (`int`, *optional*, defaults to `None`): + Dimension for the timestep embeddings. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, defaults to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. 
+ time_embedding_dim (`int`, *optional*, defaults to `None`): + An optional override for the dimension of the projected time embedding. + time_embedding_act_fn (`str`, *optional*, defaults to `None`): + Optional activation function to use only once on the time embeddings before they are passed to the rest of + the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. + timestep_post_act (`str`, *optional*, defaults to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`, + *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`, + *optional*): The dimension of the `class_labels` input when + `class_embed_type="projection"`. Required when `class_embed_type="projection"`. + class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time + embeddings with the class embeddings. + mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): + Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If + `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the + `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` + otherwise. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: Union[int, Tuple[int]] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + dropout: float = 0.0, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, + reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: float = 1.0, + time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, + time_embedding_act_fn: Optional[str] = None, + timestep_post_act: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + 
conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] = None, + attention_type: str = "default", + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, + addition_embed_type_num_heads: int = 64, + ): + super().__init__() + + self.sample_size = sample_size + + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + self._check_config( + down_block_types=down_block_types, + up_block_types=up_block_types, + only_cross_attention=only_cross_attention, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + cross_attention_dim=cross_attention_dim, + transformer_layers_per_block=transformer_layers_per_block, + reverse_transformer_layers_per_block=reverse_transformer_layers_per_block, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + ) + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + time_embed_dim, timestep_input_dim = self._set_time_proj( + time_embedding_type, + block_out_channels=block_out_channels, + flip_sin_to_cos=flip_sin_to_cos, + freq_shift=freq_shift, + time_embedding_dim=time_embedding_dim, + ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + self._set_encoder_hid_proj( + encoder_hid_dim_type, + cross_attention_dim=cross_attention_dim, + encoder_hid_dim=encoder_hid_dim, + ) + + # class embedding + self._set_class_embedding( + class_embed_type, + act_fn=act_fn, + num_class_embeds=num_class_embeds, + projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, + time_embed_dim=time_embed_dim, + timestep_input_dim=timestep_input_dim, + ) + + self._set_add_embedding( + addition_embed_type, + addition_embed_type_num_heads=addition_embed_type_num_heads, + addition_time_embed_dim=addition_time_embed_dim, + cross_attention_dim=cross_attention_dim, + encoder_hid_dim=encoder_hid_dim, + flip_sin_to_cos=flip_sin_to_cos, + freq_shift=freq_shift, + projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, + time_embed_dim=time_embed_dim, + ) + + if time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) + + self.down_blocks = nn.ModuleList([]) + 
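+        # `up_blocks` is populated further below, mirroring `down_blocks` but iterating over the
+        # reversed channel/config lists so skip connections line up.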
self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + dropout=dropout, + ) + self.down_blocks.append(down_block) + + # mid + self.mid_block = get_mid_block( + mid_block_type, + temb_channels=blocks_time_embed_dim, + in_channels=block_out_channels[-1], + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + output_scale_factor=mid_block_scale_factor, + transformer_layers_per_block=transformer_layers_per_block[-1], + num_attention_heads=num_attention_heads[-1], + cross_attention_dim=cross_attention_dim[-1], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + mid_block_only_cross_attention=mid_block_only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[-1], + dropout=dropout, + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = 
list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = ( + list(reversed(transformer_layers_per_block)) + if reverse_transformer_layers_per_block is None + else reverse_transformer_layers_per_block + ) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resolution_idx=i, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + dropout=dropout, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + + self.conv_act = get_activation(act_fn) + + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim) + + def _check_config( + self, + down_block_types: Tuple[str], + up_block_types: Tuple[str], + only_cross_attention: Union[bool, Tuple[bool]], + block_out_channels: Tuple[int], + layers_per_block: Union[int, Tuple[int]], + cross_attention_dim: Union[int, Tuple[int]], + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple[int]]], + reverse_transformer_layers_per_block: bool, + attention_head_dim: int, + num_attention_heads: Optional[Union[int, Tuple[int]]], + ): + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. 
`block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: + for layer_number_per_block in transformer_layers_per_block: + if isinstance(layer_number_per_block, list): + raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") + + def _set_time_proj( + self, + time_embedding_type: str, + block_out_channels: int, + flip_sin_to_cos: bool, + freq_shift: float, + time_embedding_dim: int, + ) -> Tuple[int, int]: + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." + ) + + return time_embed_dim, timestep_input_dim + + def _set_encoder_hid_proj( + self, + encoder_hid_dim_type: Optional[str], + cross_attention_dim: Union[int, Tuple[int]], + encoder_hid_dim: Optional[int], + ): + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." 
+ ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 + self.encoder_hid_proj = ImageProjection( + image_embed_dim=encoder_hid_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) + else: + self.encoder_hid_proj = None + + def _set_class_embedding( + self, + class_embed_type: Optional[str], + act_fn: str, + num_class_embeds: Optional[int], + projection_class_embeddings_input_dim: Optional[int], + time_embed_dim: int, + timestep_input_dim: int, + ): + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 
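+            # Here that means projecting a pre-computed `class_labels` vector of size
+            # `projection_class_embeddings_input_dim` directly to `time_embed_dim`.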
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + def _set_add_embedding( + self, + addition_embed_type: str, + addition_embed_type_num_heads: int, + addition_time_embed_dim: Optional[int], + flip_sin_to_cos: bool, + freq_shift: float, + cross_attention_dim: Optional[int], + encoder_hid_dim: Optional[int], + projection_class_embeddings_input_dim: Optional[int], + time_embed_dim: int, + ): + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif addition_embed_type == "image": + # Kandinsky 2.2 + self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet + self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int): + if attention_type in ["gated", "gated-text-image"]: + positive_len = 768 + if isinstance(cross_attention_dim, int): + positive_len = cross_attention_dim + elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list): + positive_len = cross_attention_dim[0] + + feature_type = "text-only" if attention_type == "gated" else "text-image" + self.position_net = GLIGENTextBoundingboxProjection( + positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type + ) + + @property + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. 
+ """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor) + else: + module.set_processor(processor.pop(f"{name}.processor")) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + def set_default_attn_processor(self): + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor) + + def set_attention_slice(self, slice_size: Union[str, int, List[int]] = "auto"): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module splits the input tensor in slices to compute attention in + several steps. This is useful for saving some memory in exchange for a small decrease in speed. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, input to the attention heads is halved, so attention is computed in two steps. If + `"max"`, maximum amount of memory is saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. 
+ """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_sliceable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_sliceable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_sliceable_dims(module) + + num_sliceable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_sliceable_layers * [1] + + slice_size = num_sliceable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float): + r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stage blocks where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that + are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. 
+ """ + for i, upsample_block in enumerate(self.up_blocks): + setattr(upsample_block, "s1", s1) + setattr(upsample_block, "s2", s2) + setattr(upsample_block, "b1", b1) + setattr(upsample_block, "b2", b2) + + def disable_freeu(self): + """Disables the FreeU mechanism.""" + freeu_keys = {"s1", "s2", "b1", "b2"} + for i, upsample_block in enumerate(self.up_blocks): + for k in freeu_keys: + if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: + setattr(upsample_block, k, None) + + def fuse_qkv_projections(self): + """ + Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, + key, value) are fused. For cross-attention modules, key and value projection matrices are fused. + + + + This API is 🧪 experimental. + + + """ + self.original_attn_processors = None + + for _, attn_processor in self.attn_processors.items(): + if "Added" in str(attn_processor.__class__.__name__): + raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.") + + self.original_attn_processors = self.attn_processors + + for module in self.modules(): + if isinstance(module, Attention): + module.fuse_projections(fuse=True) + + def unfuse_qkv_projections(self): + """Disables the fused QKV projection if enabled. + + + + This API is 🧪 experimental. + + + + """ + if self.original_attn_processors is not None: + self.set_attn_processor(self.original_attn_processors) + + def unload_lora(self): + """Unloads LoRA weights.""" + deprecate( + "unload_lora", + "0.28.0", + "Calling `unload_lora()` is deprecated and will be removed in a future version. Please install `peft` and then call `disable_adapters().", + ) + for module in self.modules(): + if hasattr(module, "set_lora_layer"): + module.set_lora_layer(None) + + def get_time_embed( + self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int] + ) -> Optional[torch.Tensor]: + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + return t_emb + + def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + class_emb = None + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. 
+ class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + return class_emb + + def get_aug_embed( + self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any] + ) -> Optional[torch.Tensor]: + aug_emb = None + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "text_image": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "image": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + aug_emb = self.add_embedding(image_embs) + elif self.config.addition_embed_type == "image_hint": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + hint = added_cond_kwargs.get("hint") + aug_emb = self.add_embedding(image_embs, hint) + return aug_emb + + def process_encoder_hidden_states( + self, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict[str, Any] + ) -> torch.Tensor: + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kadinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = 
self.encoder_hid_proj(encoder_hidden_states, image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + image_embeds = self.encoder_hid_proj(image_embeds) + encoder_hidden_states = (encoder_hidden_states, image_embeds) + return encoder_hidden_states + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + down_block_add_samples: Optional[Tuple[torch.Tensor]] = None, + mid_block_add_sample: Optional[Tuple[torch.Tensor]] = None, + up_block_add_samples: Optional[Tuple[torch.Tensor]] = None, + features_adapter: Optional[torch.Tensor] = None, + ) -> Union[UNet2DConditionOutput, Tuple]: + r""" + The [`UNet2DConditionModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.FloatTensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): + Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed + through the `self.time_embedding` layer to obtain the timestep embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). 
+            added_cond_kwargs: (`dict`, *optional*):
+                A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that
+                are passed along to the UNet blocks.
+            down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*):
+                A tuple of tensors that if specified are added to the residuals of the down UNet blocks (the long skip
+                connections from down blocks to up blocks), for example from ControlNet side model(s).
+            mid_block_additional_residual: (`torch.Tensor`, *optional*):
+                A tensor that if specified is added to the residual of the middle UNet block, for example from a
+                ControlNet side model.
+            down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*):
+                Additional residuals to be added within the UNet down blocks, for example from T2I-Adapter side model(s).
+            encoder_attention_mask (`torch.Tensor`):
+                A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If
+                `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias,
+                which adds large negative values to the attention scores corresponding to "discard" tokens.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+                tuple.
+            down_block_add_samples (`tuple` of `torch.Tensor`, *optional*):
+                Additive feature samples (e.g. from a BrushNet-style side model) that are summed into the hidden states
+                of the down blocks.
+            mid_block_add_sample (`torch.Tensor`, *optional*):
+                Additive feature sample that is summed into the output of the middle block.
+            up_block_add_samples (`tuple` of `torch.Tensor`, *optional*):
+                Additive feature samples that are summed into the hidden states of the up blocks.
+            features_adapter (`torch.FloatTensor`, *optional*):
+                `(batch, channels, num_frames, height, width)` adapter features tensor, added to the outputs of the
+                down blocks.
+
+        Returns:
+            [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
+                If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
+                otherwise a `tuple` is returned where the first element is the sample tensor.
+        """
+        # By default samples have to be at least a multiple of the overall upsampling factor.
+        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+        # However, the upsampling interpolation output size can be forced to fit any upsampling size
+        # on the fly if necessary.
+        default_overall_up_factor = 2**self.num_upsamplers
+
+        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+        forward_upsample_size = False
+        upsample_size = None
+
+        for dim in sample.shape[-2:]:
+            if dim % default_overall_up_factor != 0:
+                # Forward upsample size to force interpolation output size.
+                forward_upsample_size = True
+                break
+
+        # ensure attention_mask is a bias, and give it a singleton query_tokens dimension
+        # expects mask of shape:
+        #   [batch, key_tokens]
+        # adds singleton query_tokens dimension:
+        #   [batch, 1, key_tokens]
+        # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+        #   [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+        #   [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+        if attention_mask is not None:
+            # assume that mask is expressed as:
+            #   (1 = keep, 0 = discard)
+            # convert mask into a bias that can be added to attention scores:
+            #   (keep = +0, discard = -10000.0)
+            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+            attention_mask = attention_mask.unsqueeze(1)
+
+        # convert encoder_attention_mask to a bias the same way we do for attention_mask
+        if encoder_attention_mask is not None:
+            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
+            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+        # 0. center input if necessary
+        if self.config.center_input_sample:
+            sample = 2 * sample - 1.0
+
+        # 1. time
+        t_emb = self.get_time_embed(sample=sample, timestep=timestep)
+        emb = self.time_embedding(t_emb, timestep_cond)
+        aug_emb = None
+
+        class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
+        if class_emb is not None:
+            if self.config.class_embeddings_concat:
+                emb = torch.cat([emb, class_emb], dim=-1)
+            else:
+                emb = emb + class_emb
+
+        aug_emb = self.get_aug_embed(
+            emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
+        )
+        if self.config.addition_embed_type == "image_hint":
+            aug_emb, hint = aug_emb
+            sample = torch.cat([sample, hint], dim=1)
+
+        emb = emb + aug_emb if aug_emb is not None else emb
+
+        if self.time_embed_act is not None:
+            emb = self.time_embed_act(emb)
+
+        encoder_hidden_states = self.process_encoder_hidden_states(
+            encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs
+        )
+
+        # 2. pre-process
+        sample = self.conv_in(sample)
+
+        # 2.5 GLIGEN position net
+        if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
+            cross_attention_kwargs = cross_attention_kwargs.copy()
+            gligen_args = cross_attention_kwargs.pop("gligen")
+            cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
+
+        # 3. down
+        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+
+        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
+        # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets
+        is_adapter = down_intrablock_additional_residuals is not None
+        # maintain backward compatibility for legacy usage, where
+        # T2I-Adapter and ControlNet both use down_block_additional_residuals arg
+        # but can only use one or the other
+        is_brushnet = down_block_add_samples is not None and mid_block_add_sample is not None and up_block_add_samples is not None
+        if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None:
+            deprecate(
+                "T2I should not use down_block_additional_residuals",
+                "1.3.0",
+                "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \
+                       and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \
+                       for ControlNet. Please make sure to use `down_intrablock_additional_residuals` instead. 
", + standard_warn=False, + ) + down_intrablock_additional_residuals = down_block_additional_residuals + is_adapter = True + + down_block_res_samples = (sample,) + + if is_brushnet: + sample = sample + down_block_add_samples.pop(0) + + adapter_idx = 0 + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_intrablock_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) + + if is_brushnet and len(down_block_add_samples)>0: + additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0) + for _ in range(len(downsample_block.resnets)+(downsample_block.downsamplers !=None))] + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + additional_residuals = {} + if is_brushnet and len(down_block_add_samples)>0: + additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0) + for _ in range(len(downsample_block.resnets)+(downsample_block.downsamplers !=None))] + + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale, **additional_residuals) + if is_adapter and len(down_intrablock_additional_residuals) > 0: + sample += down_intrablock_additional_residuals.pop(0) + + if features_adapter is not None: + sample += features_adapter[adapter_idx] + adapter_idx += 1 + + down_block_res_samples += res_samples + + if features_adapter is not None: + assert len(features_adapter) == adapter_idx, "Wrong features_adapter" + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = self.mid_block(sample, emb) + + # To support T2I-Adapter-XL + if ( + is_adapter + and len(down_intrablock_additional_residuals) > 0 + and sample.shape == down_intrablock_additional_residuals[0].shape + ): + sample += down_intrablock_additional_residuals.pop(0) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + if is_brushnet: + sample = sample + mid_block_add_sample + + # 5. 
up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + additional_residuals = {} + if is_brushnet and len(up_block_add_samples)>0: + additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0) + for _ in range(len(upsample_block.resnets)+(upsample_block.upsamplers !=None))] + + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + additional_residuals = {} + if is_brushnet and len(up_block_add_samples)>0: + additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0) + for _ in range(len(upsample_block.resnets)+(upsample_block.upsamplers !=None))] + + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + scale=lora_scale, + **additional_residuals, + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) diff --git a/libs/unet_3d_blocks.py b/libs/unet_3d_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..1b6ce7413919d80a225a1f86a3959c5c022ed411 --- /dev/null +++ b/libs/unet_3d_blocks.py @@ -0,0 +1,2463 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
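+
+# NOTE: the blocks below follow the layout of diffusers' `unet_3d_blocks`; the temporal transformer used by the
+# 3D and motion blocks is imported from the local `libs.transformer_temporal` module rather than from diffusers.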
+ +from typing import Any, Dict, Optional, Tuple, Union + +import torch +from torch import nn + +from diffusers.utils import is_torch_version +from diffusers.utils.torch_utils import apply_freeu +from diffusers.models.attention import Attention +from diffusers.models.transformers.dual_transformer_2d import DualTransformer2DModel +from diffusers.models.resnet import ( + Downsample2D, + ResnetBlock2D, + SpatioTemporalResBlock, + TemporalConvLayer, + Upsample2D, +) +from diffusers.models.transformers.transformer_2d import Transformer2DModel +from diffusers.models.transformers.transformer_temporal import ( + TransformerSpatioTemporalModel, +) +from libs.transformer_temporal import TransformerTemporalModel + +def get_down_block( + down_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + temb_channels: int, + add_downsample: bool, + resnet_eps: float, + resnet_act_fn: str, + num_attention_heads: int, + resnet_groups: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + downsample_padding: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = True, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + temporal_num_attention_heads: int = 8, + temporal_max_seq_length: int = 32, + transformer_layers_per_block: int = 1, +) -> Union[ + "DownBlock3D", + "CrossAttnDownBlock3D", + "DownBlockMotion", + "CrossAttnDownBlockMotion", + "DownBlockSpatioTemporal", + "CrossAttnDownBlockSpatioTemporal", +]: + if down_block_type == "DownBlock3D": + return DownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") + return CrossAttnDownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + if down_block_type == "DownBlockMotion": + return DownBlockMotion( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + temporal_num_attention_heads=temporal_num_attention_heads, + temporal_max_seq_length=temporal_max_seq_length, + ) + elif down_block_type == "CrossAttnDownBlockMotion": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMotion") + return CrossAttnDownBlockMotion( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + 
temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + temporal_num_attention_heads=temporal_num_attention_heads, + temporal_max_seq_length=temporal_max_seq_length, + ) + elif down_block_type == "DownBlockSpatioTemporal": + # added for SDV + return DownBlockSpatioTemporal( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + ) + elif down_block_type == "CrossAttnDownBlockSpatioTemporal": + # added for SDV + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockSpatioTemporal") + return CrossAttnDownBlockSpatioTemporal( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + add_downsample=add_downsample, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + ) + + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type: str, + num_layers: int, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + add_upsample: bool, + resnet_eps: float, + resnet_act_fn: str, + num_attention_heads: int, + resolution_idx: Optional[int] = None, + resnet_groups: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = True, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + temporal_num_attention_heads: int = 8, + temporal_cross_attention_dim: Optional[int] = None, + temporal_max_seq_length: int = 32, + transformer_layers_per_block: int = 1, + dropout: float = 0.0, +) -> Union[ + "UpBlock3D", + "CrossAttnUpBlock3D", + "UpBlockMotion", + "CrossAttnUpBlockMotion", + "UpBlockSpatioTemporal", + "CrossAttnUpBlockSpatioTemporal", +]: + if up_block_type == "UpBlock3D": + return UpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + resolution_idx=resolution_idx, + ) + elif up_block_type == "CrossAttnUpBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") + return CrossAttnUpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + 
resnet_time_scale_shift=resnet_time_scale_shift, + resolution_idx=resolution_idx, + ) + if up_block_type == "UpBlockMotion": + return UpBlockMotion( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + resolution_idx=resolution_idx, + temporal_num_attention_heads=temporal_num_attention_heads, + temporal_max_seq_length=temporal_max_seq_length, + ) + elif up_block_type == "CrossAttnUpBlockMotion": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMotion") + return CrossAttnUpBlockMotion( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + resolution_idx=resolution_idx, + temporal_num_attention_heads=temporal_num_attention_heads, + temporal_max_seq_length=temporal_max_seq_length, + ) + elif up_block_type == "UpBlockSpatioTemporal": + # added for SDV + return UpBlockSpatioTemporal( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + resolution_idx=resolution_idx, + add_upsample=add_upsample, + ) + elif up_block_type == "CrossAttnUpBlockSpatioTemporal": + # added for SDV + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockSpatioTemporal") + return CrossAttnUpBlockSpatioTemporal( + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + num_layers=num_layers, + transformer_layers_per_block=transformer_layers_per_block, + add_upsample=add_upsample, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + resolution_idx=resolution_idx, + ) + + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock3DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + dual_cross_attention: bool = False, + use_linear_projection: bool = True, + upcast_attention: bool = False, + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + 
non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + temp_convs = [ + TemporalConvLayer( + in_channels, + in_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ] + attentions = [] + temp_attentions = [] + + for _ in range(num_layers): + attentions.append( + Transformer2DModel( + in_channels // num_attention_heads, + num_attention_heads, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + in_channels // num_attention_heads, + num_attention_heads, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + in_channels, + in_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + num_frames: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ) -> torch.FloatTensor: + hidden_states = self.resnets[0](hidden_states, temb) + hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames) + for attn, temp_attn, resnet, temp_conv in zip( + self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:] + ): + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + hidden_states = temp_attn( + hidden_states, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + return hidden_states + + +class CrossAttnDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + add_downsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + ): + super().__init__() + resnets = [] + attentions = [] + temp_attentions = [] + temp_convs = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + 
out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ) + attentions.append( + Transformer2DModel( + out_channels // num_attention_heads, + num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + out_channels // num_attention_heads, + num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + num_frames: int = 1, + cross_attention_kwargs: Dict[str, Any] = None, + ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + # TODO(Patrick, William) - attention mask is not used + output_states = () + + for resnet, temp_conv, attn, temp_attn in zip( + self.resnets, self.temp_convs, self.attentions, self.temp_attentions + ): + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + hidden_states = temp_attn( + hidden_states, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + ): + super().__init__() + resnets = [] + temp_convs = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + 
output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + num_frames: int = 1, + ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + output_states = () + + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + resolution_idx: Optional[int] = None, + ): + super().__init__() + resnets = [] + temp_convs = [] + attentions = [] + temp_attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ) + attentions.append( + Transformer2DModel( + out_channels // num_attention_heads, + num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + temp_attentions.append( + TransformerTemporalModel( + out_channels // num_attention_heads, + num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + self.attentions = nn.ModuleList(attentions) + self.temp_attentions = nn.ModuleList(temp_attentions) + + if 
add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + num_frames: int = 1, + cross_attention_kwargs: Dict[str, Any] = None, + ) -> torch.FloatTensor: + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + # TODO(Patrick, William) - attention mask is not used + for resnet, temp_conv, attn, temp_attn in zip( + self.resnets, self.temp_convs, self.attentions, self.temp_attentions + ): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + hidden_states = temp_attn( + hidden_states, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + resolution_idx: Optional[int] = None, + ): + super().__init__() + resnets = [] + temp_convs = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + temp_convs.append( + TemporalConvLayer( + out_channels, + out_channels, + dropout=0.1, + norm_num_groups=resnet_groups, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.temp_convs = nn.ModuleList(temp_convs) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( 
+ self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + upsample_size: Optional[int] = None, + num_frames: int = 1, + ) -> torch.FloatTensor: + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + for resnet, temp_conv in zip(self.resnets, self.temp_convs): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + hidden_states = resnet(hidden_states, temb) + hidden_states = temp_conv(hidden_states, num_frames=num_frames) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class DownBlockMotion(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_downsample: bool = True, + downsample_padding: int = 1, + temporal_num_attention_heads: int = 1, + temporal_cross_attention_dim: Optional[int] = None, + temporal_max_seq_length: int = 32, + ): + super().__init__() + resnets = [] + motion_modules = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + motion_modules.append( + TransformerTemporalModel( + num_attention_heads=temporal_num_attention_heads, + in_channels=out_channels, + norm_num_groups=resnet_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=False, + activation_fn="geglu", + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + attention_head_dim=out_channels // temporal_num_attention_heads, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + down_block_add_samples: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + scale: float = 1.0, + num_frames: int = 1, + ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + output_states = () + + blocks = zip(self.resnets, self.motion_modules) + for resnet, motion_module in blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + 
if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + + if down_block_add_samples is not None: + hidden_states = hidden_states + down_block_add_samples.pop(0) + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), + hidden_states.requires_grad_(), + temb, + num_frames, + **ckpt_kwargs, + ) + + else: + hidden_states = resnet(hidden_states, temb, scale=scale) + if down_block_add_samples is not None: + hidden_states = hidden_states + down_block_add_samples.pop(0) + hidden_states = motion_module(hidden_states, num_frames=num_frames) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, scale=scale) + + if down_block_add_samples is not None: + hidden_states = hidden_states + down_block_add_samples.pop(0) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnDownBlockMotion(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + downsample_padding: int = 1, + add_downsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + temporal_cross_attention_dim: Optional[int] = None, + temporal_num_attention_heads: int = 8, + temporal_max_seq_length: int = 32, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + + motion_modules.append( + TransformerTemporalModel( + num_attention_heads=temporal_num_attention_heads, 
+ in_channels=out_channels, + norm_num_groups=resnet_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=False, + activation_fn="geglu", + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + attention_head_dim=out_channels // temporal_num_attention_heads, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=downsample_padding, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + num_frames: int = 1, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + additional_residuals: Optional[torch.FloatTensor] = None, + down_block_add_samples: Optional[torch.FloatTensor] = None, + ): + output_states = () + + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + blocks = list(zip(self.resnets, self.attentions, self.motion_modules)) + for i, (resnet, attn, motion_module) in enumerate(blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + # apply additional residuals to the output of the last pair of resnet and attention blocks + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals + if down_block_add_samples is not None: + hidden_states = hidden_states + down_block_add_samples.pop(0) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), + hidden_states.requires_grad_(), + temb, + num_frames, + **ckpt_kwargs, + ) + else: + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + # apply additional residuals to the output of the last pair of resnet and attention blocks + if i == len(blocks) - 1 and additional_residuals is not None: + hidden_states = hidden_states + additional_residuals + if down_block_add_samples is not None: + hidden_states = hidden_states + down_block_add_samples.pop(0) + hidden_states = motion_module( + hidden_states, + num_frames=num_frames, + ) + + # # apply additional residuals to the output of the last pair of resnet and attention blocks 
+ # if i == len(blocks) - 1 and additional_residuals is not None: + # hidden_states = hidden_states + additional_residuals + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states, scale=lora_scale) + + if down_block_add_samples is not None: + hidden_states = hidden_states + down_block_add_samples.pop(0) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlockMotion(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + attention_type: str = "default", + temporal_cross_attention_dim: Optional[int] = None, + temporal_num_attention_heads: int = 8, + temporal_max_seq_length: int = 32, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + motion_modules.append( + TransformerTemporalModel( + num_attention_heads=temporal_num_attention_heads, + in_channels=out_channels, + norm_num_groups=resnet_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=False, + activation_fn="geglu", + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + attention_head_dim=out_channels // temporal_num_attention_heads, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + 
self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + upsample_size: Optional[int] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + num_frames: int = 1, + up_block_add_samples: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + blocks = zip(self.resnets, self.attentions, self.motion_modules) + for resnet, attn, motion_module in blocks: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + if up_block_add_samples is not None: + hidden_states = hidden_states + up_block_add_samples.pop(0) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), + hidden_states.requires_grad_(), + temb, + num_frames, + **ckpt_kwargs, + ) + else: + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + if up_block_add_samples is not None: + hidden_states = hidden_states + up_block_add_samples.pop(0) + hidden_states = motion_module( + hidden_states, + num_frames=num_frames, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale) + if up_block_add_samples is not None: + hidden_states = hidden_states + up_block_add_samples.pop(0) + + return hidden_states + + +class UpBlockMotion(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + 
resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor: float = 1.0, + add_upsample: bool = True, + temporal_norm_num_groups: int = 32, + temporal_cross_attention_dim: Optional[int] = None, + temporal_num_attention_heads: int = 8, + temporal_max_seq_length: int = 32, + ): + super().__init__() + resnets = [] + motion_modules = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock2D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + motion_modules.append( + TransformerTemporalModel( + num_attention_heads=temporal_num_attention_heads, + in_channels=out_channels, + norm_num_groups=temporal_norm_num_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=False, + activation_fn="geglu", + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + attention_head_dim=out_channels // temporal_num_attention_heads, + ) + ) + + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + upsample_size=None, + scale: float = 1.0, + num_frames: int = 1, + up_block_add_samples: Optional[torch.FloatTensor] = None, + ) -> torch.FloatTensor: + is_freeu_enabled = ( + getattr(self, "s1", None) + and getattr(self, "s2", None) + and getattr(self, "b1", None) + and getattr(self, "b2", None) + ) + + blocks = zip(self.resnets, self.motion_modules) + + for resnet, motion_module in blocks: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + # FreeU: Only operate on the first two stages + if is_freeu_enabled: + hidden_states, res_hidden_states = apply_freeu( + self.resolution_idx, + hidden_states, + res_hidden_states, + s1=self.s1, + s2=self.s2, + b1=self.b1, + b2=self.b2, + ) + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + use_reentrant=False, + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), hidden_states, temb + ) + + if up_block_add_samples is not None: + hidden_states = hidden_states + up_block_add_samples.pop(0) + + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), + hidden_states.requires_grad_(), + temb, + num_frames, + use_reentrant=False, + ) + else: + hidden_states = 
resnet(hidden_states, temb, scale=scale) + if up_block_add_samples is not None: + hidden_states = hidden_states + up_block_add_samples.pop(0) + hidden_states = motion_module(hidden_states, num_frames=num_frames) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size, scale=scale) + + if up_block_add_samples is not None: + hidden_states = hidden_states + up_block_add_samples.pop(0) + + return hidden_states + + +class UNetMidBlockCrossAttnMotion(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + transformer_layers_per_block: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + num_attention_heads: int = 1, + output_scale_factor: float = 1.0, + cross_attention_dim: int = 1280, + dual_cross_attention: float = False, + use_linear_projection: float = False, + upcast_attention: float = False, + attention_type: str = "default", + temporal_num_attention_heads: int = 1, + temporal_cross_attention_dim: Optional[int] = None, + temporal_max_seq_length: int = 32, + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + motion_modules = [] + + for _ in range(num_layers): + if not dual_cross_attention: + attentions.append( + Transformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=transformer_layers_per_block, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + ) + else: + attentions.append( + DualTransformer2DModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + ) + ) + resnets.append( + ResnetBlock2D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + motion_modules.append( + TransformerTemporalModel( + num_attention_heads=temporal_num_attention_heads, + attention_head_dim=in_channels // temporal_num_attention_heads, + in_channels=in_channels, + norm_num_groups=resnet_groups, + cross_attention_dim=temporal_cross_attention_dim, + attention_bias=False, + positional_embeddings="sinusoidal", + num_positional_embeddings=temporal_max_seq_length, + activation_fn="geglu", + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + 
temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + num_frames: int = 1, + mid_block_add_sample: Optional[Tuple[torch.Tensor]] = None, + ) -> torch.FloatTensor: + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale) + + blocks = zip(self.attentions, self.resnets[1:], self.motion_modules) + for attn, resnet, motion_module in blocks: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + ########## + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + if mid_block_add_sample is not None: + hidden_states = hidden_states + mid_block_add_sample + ################################################################ + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(motion_module), + hidden_states.requires_grad_(), + temb, + num_frames, + **ckpt_kwargs, + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + **ckpt_kwargs, + ) + else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + return_dict=False, + )[0] + ########## + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + if mid_block_add_sample is not None: + hidden_states = hidden_states + mid_block_add_sample + ################################################################ + hidden_states = motion_module( + hidden_states, + num_frames=num_frames, + ) + hidden_states = resnet(hidden_states, temb, scale=lora_scale) + + return hidden_states + + +class MidBlockTemporalDecoder(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + attention_head_dim: int = 512, + num_layers: int = 1, + upcast_attention: bool = False, + ): + super().__init__() + + resnets = [] + attentions = [] + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + resnets.append( + SpatioTemporalResBlock( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=None, + eps=1e-6, + temporal_eps=1e-5, + merge_factor=0.0, + merge_strategy="learned", + switch_spatial_to_temporal_mix=True, + ) + ) + + attentions.append( + Attention( + query_dim=in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + eps=1e-6, + upcast_attention=upcast_attention, + norm_num_groups=32, + bias=True, + residual_connection=True, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward( + self, + hidden_states: torch.FloatTensor, + image_only_indicator: 
torch.FloatTensor, + ): + hidden_states = self.resnets[0]( + hidden_states, + image_only_indicator=image_only_indicator, + ) + for resnet, attn in zip(self.resnets[1:], self.attentions): + hidden_states = attn(hidden_states) + hidden_states = resnet( + hidden_states, + image_only_indicator=image_only_indicator, + ) + + return hidden_states + + +class UpBlockTemporalDecoder(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + num_layers: int = 1, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + resnets.append( + SpatioTemporalResBlock( + in_channels=input_channels, + out_channels=out_channels, + temb_channels=None, + eps=1e-6, + temporal_eps=1e-5, + merge_factor=0.0, + merge_strategy="learned", + switch_spatial_to_temporal_mix=True, + ) + ) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + def forward( + self, + hidden_states: torch.FloatTensor, + image_only_indicator: torch.FloatTensor, + ) -> torch.FloatTensor: + for resnet in self.resnets: + hidden_states = resnet( + hidden_states, + image_only_indicator=image_only_indicator, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class UNetMidBlockSpatioTemporal(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + ): + super().__init__() + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + # support for variable transformer layers per block + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + # there is always at least one resnet + resnets = [ + SpatioTemporalResBlock( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=1e-5, + ) + ] + attentions = [] + + for i in range(num_layers): + attentions.append( + TransformerSpatioTemporalModel( + num_attention_heads, + in_channels // num_attention_heads, + in_channels=in_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + ) + ) + + resnets.append( + SpatioTemporalResBlock( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=1e-5, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + image_only_indicator: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + hidden_states = self.resnets[0]( + hidden_states, + temb, + image_only_indicator=image_only_indicator, + ) + + for attn, resnet in zip(self.attentions, self.resnets[1:]): + if self.training and self.gradient_checkpointing: # TODO + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = 
{"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + image_only_indicator, + **ckpt_kwargs, + ) + else: + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + hidden_states = resnet( + hidden_states, + temb, + image_only_indicator=image_only_indicator, + ) + + return hidden_states + + +class DownBlockSpatioTemporal(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + num_layers: int = 1, + add_downsample: bool = True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + SpatioTemporalResBlock( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=1e-5, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + image_only_indicator: Optional[torch.Tensor] = None, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + output_states = () + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + image_only_indicator, + use_reentrant=False, + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + image_only_indicator, + ) + else: + hidden_states = resnet( + hidden_states, + temb, + image_only_indicator=image_only_indicator, + ) + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnDownBlockSpatioTemporal(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + add_downsample: bool = True, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + SpatioTemporalResBlock( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=1e-6, + ) + ) + attentions.append( + TransformerSpatioTemporalModel( + num_attention_heads, + out_channels // 
num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample2D( + out_channels, + use_conv=True, + out_channels=out_channels, + padding=1, + name="op", + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states: torch.FloatTensor, + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + image_only_indicator: Optional[torch.Tensor] = None, + ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]: + output_states = () + + blocks = list(zip(self.resnets, self.attentions)) + for resnet, attn in blocks: + if self.training and self.gradient_checkpointing: # TODO + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + image_only_indicator, + **ckpt_kwargs, + ) + + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + else: + hidden_states = resnet( + hidden_states, + temb, + image_only_indicator=image_only_indicator, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + + output_states = output_states + (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states = output_states + (hidden_states,) + + return hidden_states, output_states + + +class UpBlockSpatioTemporal(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + num_layers: int = 1, + resnet_eps: float = 1e-6, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + SpatioTemporalResBlock( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + image_only_indicator: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + hidden_states = torch.cat([hidden_states, 
res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + if is_torch_version(">=", "1.11.0"): + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + image_only_indicator, + use_reentrant=False, + ) + else: + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + image_only_indicator, + ) + else: + hidden_states = resnet( + hidden_states, + temb, + image_only_indicator=image_only_indicator, + ) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states + + +class CrossAttnUpBlockSpatioTemporal(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + resolution_idx: Optional[int] = None, + num_layers: int = 1, + transformer_layers_per_block: Union[int, Tuple[int]] = 1, + resnet_eps: float = 1e-6, + num_attention_heads: int = 1, + cross_attention_dim: int = 1280, + add_upsample: bool = True, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.num_attention_heads = num_attention_heads + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * num_layers + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + SpatioTemporalResBlock( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + ) + ) + attentions.append( + TransformerSpatioTemporalModel( + num_attention_heads, + out_channels // num_attention_heads, + in_channels=out_channels, + num_layers=transformer_layers_per_block[i], + cross_attention_dim=cross_attention_dim, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + self.resolution_idx = resolution_idx + + def forward( + self, + hidden_states: torch.FloatTensor, + res_hidden_states_tuple: Tuple[torch.FloatTensor, ...], + temb: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + image_only_indicator: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: # TODO + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(resnet), + hidden_states, + temb, + image_only_indicator, + **ckpt_kwargs, + ) + 
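+                    # Note on this checkpointing pattern: create_custom_forward wraps the resnet so
+                    # torch.utils.checkpoint can re-run it with plain positional arguments during the
+                    # backward pass, trading recomputation for activation memory; use_reentrant=False
+                    # selects the non-reentrant checkpoint implementation available from torch 1.11.
+                    # The attention call below runs outside the checkpoint, so its activations are
+                    # stored as usual.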
hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + else: + hidden_states = resnet( + hidden_states, + temb, + image_only_indicator=image_only_indicator, + ) + hidden_states = attn( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + image_only_indicator=image_only_indicator, + return_dict=False, + )[0] + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states) + + return hidden_states diff --git a/libs/unet_motion_model.py b/libs/unet_motion_model.py new file mode 100644 index 0000000000000000000000000000000000000000..b8cddd585169350fdea85f243ac395a985ca0904 --- /dev/null +++ b/libs/unet_motion_model.py @@ -0,0 +1,975 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Dict, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import UNet2DConditionLoadersMixin +from diffusers.utils import logging, deprecate +from diffusers.models.attention_processor import ( + ADDED_KV_ATTENTION_PROCESSORS, + CROSS_ATTENTION_PROCESSORS, + AttentionProcessor, + AttnAddedKVProcessor, + AttnProcessor, +) +# from diffusers.models.controlnet import ControlNetConditioningEmbedding +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.transformers.transformer_temporal import TransformerTemporalModel +from diffusers.models.unets.unet_2d_blocks import UNetMidBlock2DCrossAttn +from .unet_2d_condition import UNet2DConditionModel +from .unet_3d_blocks import ( + CrossAttnDownBlockMotion, + CrossAttnUpBlockMotion, + DownBlockMotion, + UNetMidBlockCrossAttnMotion, + UpBlockMotion, + get_down_block, + get_up_block, +) +from diffusers.models.unets.unet_3d_condition import UNet3DConditionOutput + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class MotionModules(nn.Module): + def __init__( + self, + in_channels: int, + layers_per_block: int = 2, + num_attention_heads: int = 8, + attention_bias: bool = False, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + norm_num_groups: int = 32, + max_seq_length: int = 32, + ): + super().__init__() + self.motion_modules = nn.ModuleList([]) + + for i in range(layers_per_block): + self.motion_modules.append( + TransformerTemporalModel( + in_channels=in_channels, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + num_attention_heads=num_attention_heads, + attention_head_dim=in_channels // num_attention_heads, + positional_embeddings="sinusoidal", + num_positional_embeddings=max_seq_length, + ) + ) + + +class MotionAdapter(ModelMixin, ConfigMixin): + 
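+    # MotionAdapter holds only the temporal (motion) transformer layers for the down, mid, and up
+    # stages. Its weights are merged into a UNetMotionModel via `load_motion_modules` /
+    # `from_unet2d`, so a pretrained spatial UNet2DConditionModel can be extended with
+    # AnimateDiff-style motion modules while the spatial layers stay untouched.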
@register_to_config + def __init__( + self, + block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), + motion_layers_per_block: int = 2, + motion_mid_block_layers_per_block: int = 1, + motion_num_attention_heads: int = 8, + motion_norm_num_groups: int = 32, + motion_max_seq_length: int = 32, + use_motion_mid_block: bool = True, + ): + """Container to store AnimateDiff Motion Modules + + Args: + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each UNet block. + motion_layers_per_block (`int`, *optional*, defaults to 2): + The number of motion layers per UNet block. + motion_mid_block_layers_per_block (`int`, *optional*, defaults to 1): + The number of motion layers in the middle UNet block. + motion_num_attention_heads (`int`, *optional*, defaults to 8): + The number of heads to use in each attention layer of the motion module. + motion_norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use in each group normalization layer of the motion module. + motion_max_seq_length (`int`, *optional*, defaults to 32): + The maximum sequence length to use in the motion module. + use_motion_mid_block (`bool`, *optional*, defaults to True): + Whether to use a motion module in the middle of the UNet. + """ + + super().__init__() + down_blocks = [] + up_blocks = [] + + for i, channel in enumerate(block_out_channels): + output_channel = block_out_channels[i] + down_blocks.append( + MotionModules( + in_channels=output_channel, + norm_num_groups=motion_norm_num_groups, + cross_attention_dim=None, + activation_fn="geglu", + attention_bias=False, + num_attention_heads=motion_num_attention_heads, + max_seq_length=motion_max_seq_length, + layers_per_block=motion_layers_per_block, + ) + ) + + if use_motion_mid_block: + self.mid_block = MotionModules( + in_channels=block_out_channels[-1], + norm_num_groups=motion_norm_num_groups, + cross_attention_dim=None, + activation_fn="geglu", + attention_bias=False, + num_attention_heads=motion_num_attention_heads, + layers_per_block=motion_mid_block_layers_per_block, + max_seq_length=motion_max_seq_length, + ) + else: + self.mid_block = None + + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i, channel in enumerate(reversed_block_out_channels): + output_channel = reversed_block_out_channels[i] + up_blocks.append( + MotionModules( + in_channels=output_channel, + norm_num_groups=motion_norm_num_groups, + cross_attention_dim=None, + activation_fn="geglu", + attention_bias=False, + num_attention_heads=motion_num_attention_heads, + max_seq_length=motion_max_seq_length, + layers_per_block=motion_layers_per_block + 1, + ) + ) + + self.down_blocks = nn.ModuleList(down_blocks) + self.up_blocks = nn.ModuleList(up_blocks) + + def forward(self, sample): + pass + + +class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin): + r""" + A modified conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a + sample shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + conditioning_channels: int = 3, + out_channels: int = 4, + down_block_types: Tuple[str, ...] 
= ( + "CrossAttnDownBlockMotion", + "CrossAttnDownBlockMotion", + "CrossAttnDownBlockMotion", + "DownBlockMotion", + ), + mid_block_type: Optional[str] = "UNetMidBlockCrossAttnMotion", + up_block_types: Tuple[str, ...] = ( + "UpBlockMotion", + "CrossAttnUpBlockMotion", + "CrossAttnUpBlockMotion", + "CrossAttnUpBlockMotion", + ), + block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + use_linear_projection: bool = False, + num_attention_heads: Union[int, Tuple[int, ...]] = 8, + motion_max_seq_length: int = 32, + motion_num_attention_heads: int = 8, + use_motion_mid_block: int = True, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + conditioning_embedding_out_channels: Optional[Tuple[int]] = (16, 32, 96, 256), + ): + super().__init__() + + self.sample_size = sample_size + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." 
+ ) + + # input + conv_in_kernel = 3 + conv_out_kernel = 3 + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + time_embed_dim = block_out_channels[0] * 4 + self.time_proj = Timesteps(block_out_channels[0], True, 0) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + ) + + if encoder_hid_dim_type is None: + self.encoder_hid_proj = None + + # control net conditioning embedding + # self.controlnet_cond_embedding = ControlNetConditioningEmbedding( + # conditioning_embedding_channels=block_out_channels[0], + # block_out_channels=conditioning_embedding_out_channels, + # conditioning_channels=conditioning_channels, + # ) + + # class embedding + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + use_linear_projection=use_linear_projection, + dual_cross_attention=False, + temporal_num_attention_heads=motion_num_attention_heads, + temporal_max_seq_length=motion_max_seq_length, + ) + self.down_blocks.append(down_block) + + # mid + if use_motion_mid_block: + self.mid_block = UNetMidBlockCrossAttnMotion( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=False, + temporal_num_attention_heads=motion_num_attention_heads, + temporal_max_seq_length=motion_max_seq_length, + ) + + else: + self.mid_block = UNetMidBlock2DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=False, + ) + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = list(reversed(num_attention_heads)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block 
= get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=False, + resolution_idx=i, + use_linear_projection=use_linear_projection, + temporal_num_attention_heads=motion_num_attention_heads, + temporal_max_seq_length=motion_max_seq_length, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + self.conv_act = nn.SiLU() + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + @classmethod + def from_unet2d( + cls, + unet: UNet2DConditionModel, + motion_adapter: Optional[MotionAdapter] = None, + load_weights: bool = True, + ): + has_motion_adapter = motion_adapter is not None + + # based on https://github.com/guoyww/AnimateDiff/blob/895f3220c06318ea0760131ec70408b466c49333/animatediff/models/unet.py#L459 + config = unet.config + config["_class_name"] = cls.__name__ + + down_blocks = [] + for down_blocks_type in config["down_block_types"]: + if "CrossAttn" in down_blocks_type: + down_blocks.append("CrossAttnDownBlockMotion") + else: + down_blocks.append("DownBlockMotion") + config["down_block_types"] = down_blocks + + up_blocks = [] + for down_blocks_type in config["up_block_types"]: + if "CrossAttn" in down_blocks_type: + up_blocks.append("CrossAttnUpBlockMotion") + else: + up_blocks.append("UpBlockMotion") + + config["up_block_types"] = up_blocks + + if has_motion_adapter: + config["motion_num_attention_heads"] = motion_adapter.config["motion_num_attention_heads"] + config["motion_max_seq_length"] = motion_adapter.config["motion_max_seq_length"] + config["use_motion_mid_block"] = motion_adapter.config["use_motion_mid_block"] + + # Need this for backwards compatibility with UNet2DConditionModel checkpoints + if not config.get("num_attention_heads"): + config["num_attention_heads"] = config["attention_head_dim"] + + model = cls.from_config(config) + + if not load_weights: + return model + + model.conv_in.load_state_dict(unet.conv_in.state_dict()) + model.time_proj.load_state_dict(unet.time_proj.state_dict()) + model.time_embedding.load_state_dict(unet.time_embedding.state_dict()) + # model.controlnet_cond_embedding.load_state_dict(unet.controlnet_cond_embedding.state_dict()) # pose guider + + for i, down_block in enumerate(unet.down_blocks): + model.down_blocks[i].resnets.load_state_dict(down_block.resnets.state_dict()) + if hasattr(model.down_blocks[i], "attentions"): + model.down_blocks[i].attentions.load_state_dict(down_block.attentions.state_dict()) + if model.down_blocks[i].downsamplers: + model.down_blocks[i].downsamplers.load_state_dict(down_block.downsamplers.state_dict()) + + for i, up_block in enumerate(unet.up_blocks): + model.up_blocks[i].resnets.load_state_dict(up_block.resnets.state_dict()) + if hasattr(model.up_blocks[i], "attentions"): + model.up_blocks[i].attentions.load_state_dict(up_block.attentions.state_dict()) + if 
model.up_blocks[i].upsamplers: + model.up_blocks[i].upsamplers.load_state_dict(up_block.upsamplers.state_dict()) + + model.mid_block.resnets.load_state_dict(unet.mid_block.resnets.state_dict()) + model.mid_block.attentions.load_state_dict(unet.mid_block.attentions.state_dict()) + + if unet.conv_norm_out is not None: + model.conv_norm_out.load_state_dict(unet.conv_norm_out.state_dict()) + if unet.conv_act is not None: + model.conv_act.load_state_dict(unet.conv_act.state_dict()) + model.conv_out.load_state_dict(unet.conv_out.state_dict()) + + if has_motion_adapter: + model.load_motion_modules(motion_adapter) + + # ensure that the Motion UNet is the same dtype as the UNet2DConditionModel + model.to(unet.dtype) + + return model + + def freeze_unet2d_params(self) -> None: + """Freeze the weights of just the UNet2DConditionModel, and leave the motion modules + unfrozen for fine tuning. + """ + # Freeze everything + for param in self.parameters(): + param.requires_grad = False + + # Unfreeze Motion Modules + for down_block in self.down_blocks: + motion_modules = down_block.motion_modules + for param in motion_modules.parameters(): + param.requires_grad = True + + for up_block in self.up_blocks: + motion_modules = up_block.motion_modules + for param in motion_modules.parameters(): + param.requires_grad = True + + if hasattr(self.mid_block, "motion_modules"): + motion_modules = self.mid_block.motion_modules + for param in motion_modules.parameters(): + param.requires_grad = True + + def load_motion_modules(self, motion_adapter: Optional[MotionAdapter]) -> None: + for i, down_block in enumerate(motion_adapter.down_blocks): + self.down_blocks[i].motion_modules.load_state_dict(down_block.motion_modules.state_dict()) + for i, up_block in enumerate(motion_adapter.up_blocks): + self.up_blocks[i].motion_modules.load_state_dict(up_block.motion_modules.state_dict()) + + # to support older motion modules that don't have a mid_block + if hasattr(self.mid_block, "motion_modules"): + self.mid_block.motion_modules.load_state_dict(motion_adapter.mid_block.motion_modules.state_dict()) + + def save_motion_modules( + self, + save_directory: str, + is_main_process: bool = True, + safe_serialization: bool = True, + variant: Optional[str] = None, + push_to_hub: bool = False, + **kwargs, + ) -> None: + state_dict = self.state_dict() + + # Extract all motion modules + motion_state_dict = {} + for k, v in state_dict.items(): + if "motion_modules" in k: + motion_state_dict[k] = v + + adapter = MotionAdapter( + block_out_channels=self.config["block_out_channels"], + motion_layers_per_block=self.config["layers_per_block"], + motion_norm_num_groups=self.config["norm_num_groups"], + motion_num_attention_heads=self.config["motion_num_attention_heads"], + motion_max_seq_length=self.config["motion_max_seq_length"], + use_motion_mid_block=self.config["use_motion_mid_block"], + ) + adapter.load_state_dict(motion_state_dict) + adapter.save_pretrained( + save_directory=save_directory, + is_main_process=is_main_process, + safe_serialization=safe_serialization, + variant=variant, + push_to_hub=push_to_hub, + **kwargs, + ) + + @property + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.attn_processors + def attn_processors(self) -> Dict[str, AttentionProcessor]: + r""" + Returns: + `dict` of attention processors: A dictionary containing all attention processors used in the model with + indexed by its weight name. 
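+
+        Example (illustrative; `unet` is assumed to be an instantiated `UNetMotionModel`): the keys
+        follow the module path of each attention layer, e.g.
+        `"down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor"`, so the returned dict can
+        be snapshotted and later restored with `set_attn_processor`.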
+ """ + # set recursively + processors = {} + + def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]): + if hasattr(module, "get_processor"): + processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True) + + for sub_name, child in module.named_children(): + fn_recursive_add_processors(f"{name}.{sub_name}", child, processors) + + return processors + + for name, module in self.named_children(): + fn_recursive_add_processors(name, module, processors) + + return processors + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_attn_processor + def set_attn_processor( + self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]], _remove_lora=False + ): + r""" + Sets the attention processor to use to compute attention. + + Parameters: + processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`): + The instantiated processor class or a dictionary of processor classes that will be set as the processor + for **all** `Attention` layers. + + If `processor` is a dict, the key needs to define the path to the corresponding cross attention + processor. This is strongly recommended when setting trainable attention processors. + + """ + count = len(self.attn_processors.keys()) + + if isinstance(processor, dict) and len(processor) != count: + raise ValueError( + f"A dict of processors was passed, but the number of processors {len(processor)} does not match the" + f" number of attention layers: {count}. Please make sure to pass {count} processor classes." + ) + + def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor): + if hasattr(module, "set_processor"): + if not isinstance(processor, dict): + module.set_processor(processor, _remove_lora=_remove_lora) + else: + module.set_processor(processor.pop(f"{name}.processor"), _remove_lora=_remove_lora) + + for sub_name, child in module.named_children(): + fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor) + + for name, module in self.named_children(): + fn_recursive_attn_processor(name, module, processor) + + # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking + def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None: + """ + Sets the attention processor to use [feed forward + chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers). + + Parameters: + chunk_size (`int`, *optional*): + The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually + over each tensor of dim=`dim`. + dim (`int`, *optional*, defaults to `0`): + The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch) + or dim=1 (sequence length). 
+ """ + if dim not in [0, 1]: + raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}") + + # By default chunk size is 1 + chunk_size = chunk_size or 1 + + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, chunk_size, dim) + + # Copied from diffusers.models.unet_3d_condition.UNet3DConditionModel.disable_forward_chunking + def disable_forward_chunking(self) -> None: + def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int): + if hasattr(module, "set_chunk_feed_forward"): + module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim) + + for child in module.children(): + fn_recursive_feed_forward(child, chunk_size, dim) + + for module in self.children(): + fn_recursive_feed_forward(module, None, 0) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor + def set_default_attn_processor(self) -> None: + """ + Disables custom attention processors and sets the default attention implementation. + """ + if all(proc.__class__ in ADDED_KV_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnAddedKVProcessor() + elif all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()): + processor = AttnProcessor() + else: + raise ValueError( + f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}" + ) + + self.set_attn_processor(processor, _remove_lora=True) + + def _set_gradient_checkpointing(self, module, value: bool = False) -> None: + if isinstance(module, (CrossAttnDownBlockMotion, DownBlockMotion, CrossAttnUpBlockMotion, UpBlockMotion)): + module.gradient_checkpointing = value + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.enable_freeu + def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> None: + r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497. + + The suffixes after the scaling factors represent the stage blocks where they are being applied. + + Please refer to the [official repository](https://github.com/ChenyangSi/FreeU) for combinations of values that + are known to work well for different pipelines such as Stable Diffusion v1, v2, and Stable Diffusion XL. + + Args: + s1 (`float`): + Scaling factor for stage 1 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + s2 (`float`): + Scaling factor for stage 2 to attenuate the contributions of the skip features. This is done to + mitigate the "oversmoothing effect" in the enhanced denoising process. + b1 (`float`): Scaling factor for stage 1 to amplify the contributions of backbone features. + b2 (`float`): Scaling factor for stage 2 to amplify the contributions of backbone features. 
+ """ + for i, upsample_block in enumerate(self.up_blocks): + setattr(upsample_block, "s1", s1) + setattr(upsample_block, "s2", s2) + setattr(upsample_block, "b1", b1) + setattr(upsample_block, "b2", b2) + + # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel.disable_freeu + def disable_freeu(self) -> None: + """Disables the FreeU mechanism.""" + freeu_keys = {"s1", "s2", "b1", "b2"} + for i, upsample_block in enumerate(self.up_blocks): + for k in freeu_keys: + if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None: + setattr(upsample_block, k, None) + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + # controlnet_cond: torch.FloatTensor, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + return_dict: bool = True, + num_frames: int = 24, + down_block_add_samples: Optional[Tuple[torch.Tensor]] = None, + mid_block_add_sample: Optional[Tuple[torch.Tensor]] = None, + up_block_add_samples: Optional[Tuple[torch.Tensor]] = None, + ) -> Union[UNet3DConditionOutput, Tuple[torch.Tensor]]: + r""" + The [`UNetMotionModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch * num_frames, channel, height, width`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.FloatTensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): + Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed + through the `self.time_embedding` layer to obtain the timestep embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): + A tuple of tensors that if specified are added to the residuals of down unet blocks. + mid_block_additional_residual: (`torch.Tensor`, *optional*): + A tensor that if specified is added to the residual of the middle unet block. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_3d_condition.UNet3DConditionOutput`] instead of a plain + tuple. + + Returns: + [`~models.unet_3d_condition.UNet3DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unet_3d_condition.UNet3DConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. 
+ """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0] // num_frames) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + emb = emb.repeat_interleave(repeats=num_frames, dim=0) + + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype) + encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1) + + encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0) + + # 2. pre-process + # sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:]) + # N*T C H W + sample = self.conv_in(sample) + # controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) + # sample += controlnet_cond + + # 3. 
down + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets + is_adapter = down_intrablock_additional_residuals is not None + # maintain backward compatibility for legacy usage, where + # T2I-Adapter and ControlNet both use down_block_additional_residuals arg + # but can only use one or the other + is_brushnet = down_block_add_samples is not None and mid_block_add_sample is not None and up_block_add_samples is not None + if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: + deprecate( + "T2I should not use down_block_additional_residuals", + "1.3.0", + "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ + and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ + for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", + standard_warn=False, + ) + down_intrablock_additional_residuals = down_block_additional_residuals + is_adapter = True + + down_block_res_samples = (sample,) + if is_brushnet: + sample = sample + down_block_add_samples.pop(0) + + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_intrablock_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) + if is_brushnet and len(down_block_add_samples)>0: + additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0) + for _ in range(len(downsample_block.resnets)+(downsample_block.downsamplers !=None))] + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + **additional_residuals, + ) + else: + additional_residuals = {} + if is_brushnet and len(down_block_add_samples)>0: + additional_residuals["down_block_add_samples"] = [down_block_add_samples.pop(0) + for _ in range(len(downsample_block.resnets)+(downsample_block.downsamplers !=None))] + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, num_frames=num_frames, **additional_residuals,) + if is_adapter and len(down_intrablock_additional_residuals) > 0: + sample += down_intrablock_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + if down_block_additional_residuals is not None: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples += (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. 
mid + if self.mid_block is not None: + # To support older versions of motion modules that don't have a mid_block + if hasattr(self.mid_block, "motion_modules"): + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + mid_block_add_sample=mid_block_add_sample, + ) + else: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + mid_block_add_sample=mid_block_add_sample, + ) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # if is_brushnet: + # sample = sample + mid_block_add_sample + + if mid_block_additional_residual is not None: + sample = sample + mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + additional_residuals = {} + if is_brushnet and len(up_block_add_samples)>0: + additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0) + for _ in range(len(upsample_block.resnets)+(upsample_block.upsamplers !=None))] + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + num_frames=num_frames, + cross_attention_kwargs=cross_attention_kwargs, + **additional_residuals, + ) + else: + additional_residuals = {} + if is_brushnet and len(up_block_add_samples)>0: + additional_residuals["up_block_add_samples"] = [up_block_add_samples.pop(0) + for _ in range(len(upsample_block.resnets)+(upsample_block.upsamplers !=None))] + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + num_frames=num_frames, + **additional_residuals, + ) + + # 6. 
post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + + sample = self.conv_out(sample) + + # reshape to (batch, framerate, channel, width, height) + # sample = sample[None, :].reshape((-1, num_frames) + sample.shape[1:]) + + if not return_dict: + return (sample,) + + return UNet3DConditionOutput(sample=sample) diff --git a/propainter/RAFT/__init__.py b/propainter/RAFT/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e7179ea3ce4ad81425c619772d4bc47bc7ceea3a --- /dev/null +++ b/propainter/RAFT/__init__.py @@ -0,0 +1,2 @@ +# from .demo import RAFT_infer +from .raft import RAFT diff --git a/propainter/RAFT/corr.py b/propainter/RAFT/corr.py new file mode 100644 index 0000000000000000000000000000000000000000..449dbd963b8303eda242a65063ca857b95475721 --- /dev/null +++ b/propainter/RAFT/corr.py @@ -0,0 +1,111 @@ +import torch +import torch.nn.functional as F +from .utils.utils import bilinear_sampler, coords_grid + +try: + import alt_cuda_corr +except: + # alt_cuda_corr is not compiled + pass + + +class CorrBlock: + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + self.num_levels = num_levels + self.radius = radius + self.corr_pyramid = [] + + # all pairs correlation + corr = CorrBlock.corr(fmap1, fmap2) + + batch, h1, w1, dim, h2, w2 = corr.shape + corr = corr.reshape(batch*h1*w1, dim, h2, w2) + + self.corr_pyramid.append(corr) + for i in range(self.num_levels-1): + corr = F.avg_pool2d(corr, 2, stride=2) + self.corr_pyramid.append(corr) + + def __call__(self, coords): + r = self.radius + coords = coords.permute(0, 2, 3, 1) + batch, h1, w1, _ = coords.shape + + out_pyramid = [] + for i in range(self.num_levels): + corr = self.corr_pyramid[i] + dx = torch.linspace(-r, r, 2*r+1) + dy = torch.linspace(-r, r, 2*r+1) + delta = torch.stack(torch.meshgrid(dy, dx), axis=-1).to(coords.device) + + centroid_lvl = coords.reshape(batch*h1*w1, 1, 1, 2) / 2**i + delta_lvl = delta.view(1, 2*r+1, 2*r+1, 2) + coords_lvl = centroid_lvl + delta_lvl + + corr = bilinear_sampler(corr, coords_lvl) + corr = corr.view(batch, h1, w1, -1) + out_pyramid.append(corr) + + out = torch.cat(out_pyramid, dim=-1) + return out.permute(0, 3, 1, 2).contiguous().float() + + @staticmethod + def corr(fmap1, fmap2): + batch, dim, ht, wd = fmap1.shape + fmap1 = fmap1.view(batch, dim, ht*wd) + fmap2 = fmap2.view(batch, dim, ht*wd) + + corr = torch.matmul(fmap1.transpose(1,2), fmap2) + corr = corr.view(batch, ht, wd, 1, ht, wd) + return corr / torch.sqrt(torch.tensor(dim).float()) + + +class CorrLayer(torch.autograd.Function): + @staticmethod + def forward(ctx, fmap1, fmap2, coords, r): + fmap1 = fmap1.contiguous() + fmap2 = fmap2.contiguous() + coords = coords.contiguous() + ctx.save_for_backward(fmap1, fmap2, coords) + ctx.r = r + corr, = correlation_cudaz.forward(fmap1, fmap2, coords, ctx.r) + return corr + + @staticmethod + def backward(ctx, grad_corr): + fmap1, fmap2, coords = ctx.saved_tensors + grad_corr = grad_corr.contiguous() + fmap1_grad, fmap2_grad, coords_grad = \ + correlation_cudaz.backward(fmap1, fmap2, coords, grad_corr, ctx.r) + return fmap1_grad, fmap2_grad, coords_grad, None + + +class AlternateCorrBlock: + def __init__(self, fmap1, fmap2, num_levels=4, radius=4): + self.num_levels = num_levels + self.radius = radius + + self.pyramid = [(fmap1, fmap2)] + for i in range(self.num_levels): + fmap1 = F.avg_pool2d(fmap1, 2, stride=2) + fmap2 = F.avg_pool2d(fmap2, 2, stride=2) + self.pyramid.append((fmap1, fmap2)) + + def 
__call__(self, coords): + + coords = coords.permute(0, 2, 3, 1) + B, H, W, _ = coords.shape + + corr_list = [] + for i in range(self.num_levels): + r = self.radius + fmap1_i = self.pyramid[0][0].permute(0, 2, 3, 1) + fmap2_i = self.pyramid[i][1].permute(0, 2, 3, 1) + + coords_i = (coords / 2**i).reshape(B, 1, H, W, 2).contiguous() + corr = alt_cuda_corr(fmap1_i, fmap2_i, coords_i, r) + corr_list.append(corr.squeeze(1)) + + corr = torch.stack(corr_list, dim=1) + corr = corr.reshape(B, -1, H, W) + return corr / 16.0 diff --git a/propainter/RAFT/datasets.py b/propainter/RAFT/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..3411fdacfb900024005e8997d07c600e963a95ca --- /dev/null +++ b/propainter/RAFT/datasets.py @@ -0,0 +1,235 @@ +# Data loading based on https://github.com/NVIDIA/flownet2-pytorch + +import numpy as np +import torch +import torch.utils.data as data +import torch.nn.functional as F + +import os +import math +import random +from glob import glob +import os.path as osp + +from utils import frame_utils +from utils.augmentor import FlowAugmentor, SparseFlowAugmentor + + +class FlowDataset(data.Dataset): + def __init__(self, aug_params=None, sparse=False): + self.augmentor = None + self.sparse = sparse + if aug_params is not None: + if sparse: + self.augmentor = SparseFlowAugmentor(**aug_params) + else: + self.augmentor = FlowAugmentor(**aug_params) + + self.is_test = False + self.init_seed = False + self.flow_list = [] + self.image_list = [] + self.extra_info = [] + + def __getitem__(self, index): + + if self.is_test: + img1 = frame_utils.read_gen(self.image_list[index][0]) + img2 = frame_utils.read_gen(self.image_list[index][1]) + img1 = np.array(img1).astype(np.uint8)[..., :3] + img2 = np.array(img2).astype(np.uint8)[..., :3] + img1 = torch.from_numpy(img1).permute(2, 0, 1).float() + img2 = torch.from_numpy(img2).permute(2, 0, 1).float() + return img1, img2, self.extra_info[index] + + if not self.init_seed: + worker_info = torch.utils.data.get_worker_info() + if worker_info is not None: + torch.manual_seed(worker_info.id) + np.random.seed(worker_info.id) + random.seed(worker_info.id) + self.init_seed = True + + index = index % len(self.image_list) + valid = None + if self.sparse: + flow, valid = frame_utils.readFlowKITTI(self.flow_list[index]) + else: + flow = frame_utils.read_gen(self.flow_list[index]) + + img1 = frame_utils.read_gen(self.image_list[index][0]) + img2 = frame_utils.read_gen(self.image_list[index][1]) + + flow = np.array(flow).astype(np.float32) + img1 = np.array(img1).astype(np.uint8) + img2 = np.array(img2).astype(np.uint8) + + # grayscale images + if len(img1.shape) == 2: + img1 = np.tile(img1[...,None], (1, 1, 3)) + img2 = np.tile(img2[...,None], (1, 1, 3)) + else: + img1 = img1[..., :3] + img2 = img2[..., :3] + + if self.augmentor is not None: + if self.sparse: + img1, img2, flow, valid = self.augmentor(img1, img2, flow, valid) + else: + img1, img2, flow = self.augmentor(img1, img2, flow) + + img1 = torch.from_numpy(img1).permute(2, 0, 1).float() + img2 = torch.from_numpy(img2).permute(2, 0, 1).float() + flow = torch.from_numpy(flow).permute(2, 0, 1).float() + + if valid is not None: + valid = torch.from_numpy(valid) + else: + valid = (flow[0].abs() < 1000) & (flow[1].abs() < 1000) + + return img1, img2, flow, valid.float() + + + def __rmul__(self, v): + self.flow_list = v * self.flow_list + self.image_list = v * self.image_list + return self + + def __len__(self): + return len(self.image_list) + + +class MpiSintel(FlowDataset): 
+ def __init__(self, aug_params=None, split='training', root='datasets/Sintel', dstype='clean'): + super(MpiSintel, self).__init__(aug_params) + flow_root = osp.join(root, split, 'flow') + image_root = osp.join(root, split, dstype) + + if split == 'test': + self.is_test = True + + for scene in os.listdir(image_root): + image_list = sorted(glob(osp.join(image_root, scene, '*.png'))) + for i in range(len(image_list)-1): + self.image_list += [ [image_list[i], image_list[i+1]] ] + self.extra_info += [ (scene, i) ] # scene and frame_id + + if split != 'test': + self.flow_list += sorted(glob(osp.join(flow_root, scene, '*.flo'))) + + +class FlyingChairs(FlowDataset): + def __init__(self, aug_params=None, split='train', root='datasets/FlyingChairs_release/data'): + super(FlyingChairs, self).__init__(aug_params) + + images = sorted(glob(osp.join(root, '*.ppm'))) + flows = sorted(glob(osp.join(root, '*.flo'))) + assert (len(images)//2 == len(flows)) + + split_list = np.loadtxt('chairs_split.txt', dtype=np.int32) + for i in range(len(flows)): + xid = split_list[i] + if (split=='training' and xid==1) or (split=='validation' and xid==2): + self.flow_list += [ flows[i] ] + self.image_list += [ [images[2*i], images[2*i+1]] ] + + +class FlyingThings3D(FlowDataset): + def __init__(self, aug_params=None, root='datasets/FlyingThings3D', dstype='frames_cleanpass'): + super(FlyingThings3D, self).__init__(aug_params) + + for cam in ['left']: + for direction in ['into_future', 'into_past']: + image_dirs = sorted(glob(osp.join(root, dstype, 'TRAIN/*/*'))) + image_dirs = sorted([osp.join(f, cam) for f in image_dirs]) + + flow_dirs = sorted(glob(osp.join(root, 'optical_flow/TRAIN/*/*'))) + flow_dirs = sorted([osp.join(f, direction, cam) for f in flow_dirs]) + + for idir, fdir in zip(image_dirs, flow_dirs): + images = sorted(glob(osp.join(idir, '*.png')) ) + flows = sorted(glob(osp.join(fdir, '*.pfm')) ) + for i in range(len(flows)-1): + if direction == 'into_future': + self.image_list += [ [images[i], images[i+1]] ] + self.flow_list += [ flows[i] ] + elif direction == 'into_past': + self.image_list += [ [images[i+1], images[i]] ] + self.flow_list += [ flows[i+1] ] + + +class KITTI(FlowDataset): + def __init__(self, aug_params=None, split='training', root='datasets/KITTI'): + super(KITTI, self).__init__(aug_params, sparse=True) + if split == 'testing': + self.is_test = True + + root = osp.join(root, split) + images1 = sorted(glob(osp.join(root, 'image_2/*_10.png'))) + images2 = sorted(glob(osp.join(root, 'image_2/*_11.png'))) + + for img1, img2 in zip(images1, images2): + frame_id = img1.split('/')[-1] + self.extra_info += [ [frame_id] ] + self.image_list += [ [img1, img2] ] + + if split == 'training': + self.flow_list = sorted(glob(osp.join(root, 'flow_occ/*_10.png'))) + + +class HD1K(FlowDataset): + def __init__(self, aug_params=None, root='datasets/HD1k'): + super(HD1K, self).__init__(aug_params, sparse=True) + + seq_ix = 0 + while 1: + flows = sorted(glob(os.path.join(root, 'hd1k_flow_gt', 'flow_occ/%06d_*.png' % seq_ix))) + images = sorted(glob(os.path.join(root, 'hd1k_input', 'image_2/%06d_*.png' % seq_ix))) + + if len(flows) == 0: + break + + for i in range(len(flows)-1): + self.flow_list += [flows[i]] + self.image_list += [ [images[i], images[i+1]] ] + + seq_ix += 1 + + +def fetch_dataloader(args, TRAIN_DS='C+T+K+S+H'): + """ Create the data loader for the corresponding trainign set """ + + if args.stage == 'chairs': + aug_params = {'crop_size': args.image_size, 'min_scale': -0.1, 'max_scale': 1.0, 
'do_flip': True} + train_dataset = FlyingChairs(aug_params, split='training') + + elif args.stage == 'things': + aug_params = {'crop_size': args.image_size, 'min_scale': -0.4, 'max_scale': 0.8, 'do_flip': True} + clean_dataset = FlyingThings3D(aug_params, dstype='frames_cleanpass') + final_dataset = FlyingThings3D(aug_params, dstype='frames_finalpass') + train_dataset = clean_dataset + final_dataset + + elif args.stage == 'sintel': + aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.6, 'do_flip': True} + things = FlyingThings3D(aug_params, dstype='frames_cleanpass') + sintel_clean = MpiSintel(aug_params, split='training', dstype='clean') + sintel_final = MpiSintel(aug_params, split='training', dstype='final') + + if TRAIN_DS == 'C+T+K+S+H': + kitti = KITTI({'crop_size': args.image_size, 'min_scale': -0.3, 'max_scale': 0.5, 'do_flip': True}) + hd1k = HD1K({'crop_size': args.image_size, 'min_scale': -0.5, 'max_scale': 0.2, 'do_flip': True}) + train_dataset = 100*sintel_clean + 100*sintel_final + 200*kitti + 5*hd1k + things + + elif TRAIN_DS == 'C+T+K/S': + train_dataset = 100*sintel_clean + 100*sintel_final + things + + elif args.stage == 'kitti': + aug_params = {'crop_size': args.image_size, 'min_scale': -0.2, 'max_scale': 0.4, 'do_flip': False} + train_dataset = KITTI(aug_params, split='training') + + train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size, + pin_memory=False, shuffle=True, num_workers=4, drop_last=True) + + print('Training with %d image pairs' % len(train_dataset)) + return train_loader + diff --git a/propainter/RAFT/demo.py b/propainter/RAFT/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..096963bdbb36aed3df673f131d6e044d8c6f95ea --- /dev/null +++ b/propainter/RAFT/demo.py @@ -0,0 +1,79 @@ +import sys +import argparse +import os +import cv2 +import glob +import numpy as np +import torch +from PIL import Image + +from .raft import RAFT +from .utils import flow_viz +from .utils.utils import InputPadder + + + +DEVICE = 'cuda' + +def load_image(imfile): + img = np.array(Image.open(imfile)).astype(np.uint8) + img = torch.from_numpy(img).permute(2, 0, 1).float() + return img + + +def load_image_list(image_files): + images = [] + for imfile in sorted(image_files): + images.append(load_image(imfile)) + + images = torch.stack(images, dim=0) + images = images.to(DEVICE) + + padder = InputPadder(images.shape) + return padder.pad(images)[0] + + +def viz(img, flo): + img = img[0].permute(1,2,0).cpu().numpy() + flo = flo[0].permute(1,2,0).cpu().numpy() + + # map flow to rgb image + flo = flow_viz.flow_to_image(flo) + # img_flo = np.concatenate([img, flo], axis=0) + img_flo = flo + + cv2.imwrite('/home/chengao/test/flow.png', img_flo[:, :, [2,1,0]]) + # cv2.imshow('image', img_flo[:, :, [2,1,0]]/255.0) + # cv2.waitKey() + + +def demo(args): + model = torch.nn.DataParallel(RAFT(args)) + model.load_state_dict(torch.load(args.model)) + + model = model.module + model.to(DEVICE) + model.eval() + + with torch.no_grad(): + images = glob.glob(os.path.join(args.path, '*.png')) + \ + glob.glob(os.path.join(args.path, '*.jpg')) + + images = load_image_list(images) + for i in range(images.shape[0]-1): + image1 = images[i,None] + image2 = images[i+1,None] + + flow_low, flow_up = model(image1, image2, iters=20, test_mode=True) + viz(image1, flow_up) + + +def RAFT_infer(args): + model = torch.nn.DataParallel(RAFT(args)) + model.load_state_dict(torch.load(args.model)) + + model = model.module + model.to(DEVICE) + model.eval() + + 
return model diff --git a/propainter/RAFT/extractor.py b/propainter/RAFT/extractor.py new file mode 100644 index 0000000000000000000000000000000000000000..9a9c759d1243d4694e8656c2f6f8a37e53edd009 --- /dev/null +++ b/propainter/RAFT/extractor.py @@ -0,0 +1,267 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class ResidualBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(ResidualBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, padding=1, stride=stride) + self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes) + self.norm2 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm3 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes) + self.norm2 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm3 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + if not stride == 1: + self.norm3 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm3) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + + if self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + + + +class BottleneckBlock(nn.Module): + def __init__(self, in_planes, planes, norm_fn='group', stride=1): + super(BottleneckBlock, self).__init__() + + self.conv1 = nn.Conv2d(in_planes, planes//4, kernel_size=1, padding=0) + self.conv2 = nn.Conv2d(planes//4, planes//4, kernel_size=3, padding=1, stride=stride) + self.conv3 = nn.Conv2d(planes//4, planes, kernel_size=1, padding=0) + self.relu = nn.ReLU(inplace=True) + + num_groups = planes // 8 + + if norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes//4) + self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + if not stride == 1: + self.norm4 = nn.GroupNorm(num_groups=num_groups, num_channels=planes) + + elif norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(planes//4) + self.norm2 = nn.BatchNorm2d(planes//4) + self.norm3 = nn.BatchNorm2d(planes) + if not stride == 1: + self.norm4 = nn.BatchNorm2d(planes) + + elif norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(planes//4) + self.norm2 = nn.InstanceNorm2d(planes//4) + self.norm3 = nn.InstanceNorm2d(planes) + if not stride == 1: + self.norm4 = nn.InstanceNorm2d(planes) + + elif norm_fn == 'none': + self.norm1 = nn.Sequential() + self.norm2 = nn.Sequential() + self.norm3 = nn.Sequential() + if not stride == 1: + self.norm4 = nn.Sequential() + + if stride == 1: + self.downsample = None + + else: + self.downsample = nn.Sequential( + nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride), self.norm4) + + + def forward(self, x): + y = x + y = self.relu(self.norm1(self.conv1(y))) + y = self.relu(self.norm2(self.conv2(y))) + y = self.relu(self.norm3(self.conv3(y))) + + if 
self.downsample is not None: + x = self.downsample(x) + + return self.relu(x+y) + +class BasicEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(BasicEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=64) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(64) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(64) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 64 + self.layer1 = self._make_layer(64, stride=1) + self.layer2 = self._make_layer(96, stride=2) + self.layer3 = self._make_layer(128, stride=2) + + # output convolution + self.conv2 = nn.Conv2d(128, output_dim, kernel_size=1) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = ResidualBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = ResidualBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x + + +class SmallEncoder(nn.Module): + def __init__(self, output_dim=128, norm_fn='batch', dropout=0.0): + super(SmallEncoder, self).__init__() + self.norm_fn = norm_fn + + if self.norm_fn == 'group': + self.norm1 = nn.GroupNorm(num_groups=8, num_channels=32) + + elif self.norm_fn == 'batch': + self.norm1 = nn.BatchNorm2d(32) + + elif self.norm_fn == 'instance': + self.norm1 = nn.InstanceNorm2d(32) + + elif self.norm_fn == 'none': + self.norm1 = nn.Sequential() + + self.conv1 = nn.Conv2d(3, 32, kernel_size=7, stride=2, padding=3) + self.relu1 = nn.ReLU(inplace=True) + + self.in_planes = 32 + self.layer1 = self._make_layer(32, stride=1) + self.layer2 = self._make_layer(64, stride=2) + self.layer3 = self._make_layer(96, stride=2) + + self.dropout = None + if dropout > 0: + self.dropout = nn.Dropout2d(p=dropout) + + self.conv2 = nn.Conv2d(96, output_dim, kernel_size=1) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d, nn.GroupNorm)): + if m.weight is not None: + nn.init.constant_(m.weight, 1) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + def _make_layer(self, dim, stride=1): + layer1 = BottleneckBlock(self.in_planes, dim, self.norm_fn, stride=stride) + layer2 = BottleneckBlock(dim, dim, self.norm_fn, stride=1) + layers = (layer1, layer2) + + self.in_planes = dim + return nn.Sequential(*layers) + + + def 
forward(self, x): + + # if input is list, combine batch dimension + is_list = isinstance(x, tuple) or isinstance(x, list) + if is_list: + batch_dim = x[0].shape[0] + x = torch.cat(x, dim=0) + + x = self.conv1(x) + x = self.norm1(x) + x = self.relu1(x) + + x = self.layer1(x) + x = self.layer2(x) + x = self.layer3(x) + x = self.conv2(x) + + if self.training and self.dropout is not None: + x = self.dropout(x) + + if is_list: + x = torch.split(x, [batch_dim, batch_dim], dim=0) + + return x diff --git a/propainter/RAFT/raft.py b/propainter/RAFT/raft.py new file mode 100644 index 0000000000000000000000000000000000000000..829ef97b8d3e280aac59ebef7bb2eaf06274b62a --- /dev/null +++ b/propainter/RAFT/raft.py @@ -0,0 +1,146 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .update import BasicUpdateBlock, SmallUpdateBlock +from .extractor import BasicEncoder, SmallEncoder +from .corr import CorrBlock, AlternateCorrBlock +from .utils.utils import bilinear_sampler, coords_grid, upflow8 + +try: + autocast = torch.cuda.amp.autocast +except: + # dummy autocast for PyTorch < 1.6 + class autocast: + def __init__(self, enabled): + pass + def __enter__(self): + pass + def __exit__(self, *args): + pass + + +class RAFT(nn.Module): + def __init__(self, args): + super(RAFT, self).__init__() + self.args = args + + if args.small: + self.hidden_dim = hdim = 96 + self.context_dim = cdim = 64 + args.corr_levels = 4 + args.corr_radius = 3 + + else: + self.hidden_dim = hdim = 128 + self.context_dim = cdim = 128 + args.corr_levels = 4 + args.corr_radius = 4 + + if 'dropout' not in args._get_kwargs(): + args.dropout = 0 + + if 'alternate_corr' not in args._get_kwargs(): + args.alternate_corr = False + + # feature network, context network, and update block + if args.small: + self.fnet = SmallEncoder(output_dim=128, norm_fn='instance', dropout=args.dropout) + self.cnet = SmallEncoder(output_dim=hdim+cdim, norm_fn='none', dropout=args.dropout) + self.update_block = SmallUpdateBlock(self.args, hidden_dim=hdim) + + else: + self.fnet = BasicEncoder(output_dim=256, norm_fn='instance', dropout=args.dropout) + self.cnet = BasicEncoder(output_dim=hdim+cdim, norm_fn='batch', dropout=args.dropout) + self.update_block = BasicUpdateBlock(self.args, hidden_dim=hdim) + + + def freeze_bn(self): + for m in self.modules(): + if isinstance(m, nn.BatchNorm2d): + m.eval() + + def initialize_flow(self, img): + """ Flow is represented as difference between two coordinate grids flow = coords1 - coords0""" + N, C, H, W = img.shape + coords0 = coords_grid(N, H//8, W//8).to(img.device) + coords1 = coords_grid(N, H//8, W//8).to(img.device) + + # optical flow computed as difference: flow = coords1 - coords0 + return coords0, coords1 + + def upsample_flow(self, flow, mask): + """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ + N, _, H, W = flow.shape + mask = mask.view(N, 1, 9, 8, 8, H, W) + mask = torch.softmax(mask, dim=2) + + up_flow = F.unfold(8 * flow, [3,3], padding=1) + up_flow = up_flow.view(N, 2, 9, 1, 1, H, W) + + up_flow = torch.sum(mask * up_flow, dim=2) + up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) + return up_flow.reshape(N, 2, 8*H, 8*W) + + + def forward(self, image1, image2, iters=12, flow_init=None, test_mode=True): + """ Estimate optical flow between pair of frames """ + + # image1 = 2 * (image1 / 255.0) - 1.0 + # image2 = 2 * (image2 / 255.0) - 1.0 + + image1 = image1.contiguous() + image2 = image2.contiguous() + + hdim = self.hidden_dim + cdim = 
self.context_dim + + # run the feature network + with autocast(enabled=self.args.mixed_precision): + fmap1, fmap2 = self.fnet([image1, image2]) + + fmap1 = fmap1.float() + fmap2 = fmap2.float() + + if self.args.alternate_corr: + corr_fn = AlternateCorrBlock(fmap1, fmap2, radius=self.args.corr_radius) + else: + corr_fn = CorrBlock(fmap1, fmap2, radius=self.args.corr_radius) + + # run the context network + with autocast(enabled=self.args.mixed_precision): + cnet = self.cnet(image1) + net, inp = torch.split(cnet, [hdim, cdim], dim=1) + net = torch.tanh(net) + inp = torch.relu(inp) + + coords0, coords1 = self.initialize_flow(image1) + + if flow_init is not None: + coords1 = coords1 + flow_init + + flow_predictions = [] + for itr in range(iters): + coords1 = coords1.detach() + corr = corr_fn(coords1) # index correlation volume + + flow = coords1 - coords0 + with autocast(enabled=self.args.mixed_precision): + net, up_mask, delta_flow = self.update_block(net, inp, corr, flow) + + # F(t+1) = F(t) + \Delta(t) + coords1 = coords1 + delta_flow + + # upsample predictions + if up_mask is None: + flow_up = upflow8(coords1 - coords0) + else: + flow_up = self.upsample_flow(coords1 - coords0, up_mask) + + flow_predictions.append(flow_up) + + if test_mode: + return coords1 - coords0, flow_up + + return flow_predictions diff --git a/propainter/RAFT/update.py b/propainter/RAFT/update.py new file mode 100644 index 0000000000000000000000000000000000000000..f940497f9b5eb1c12091574fe9a0223a1b196d50 --- /dev/null +++ b/propainter/RAFT/update.py @@ -0,0 +1,139 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class FlowHead(nn.Module): + def __init__(self, input_dim=128, hidden_dim=256): + super(FlowHead, self).__init__() + self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) + self.conv2 = nn.Conv2d(hidden_dim, 2, 3, padding=1) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + return self.conv2(self.relu(self.conv1(x))) + +class ConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192+128): + super(ConvGRU, self).__init__() + self.convz = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + self.convr = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + self.convq = nn.Conv2d(hidden_dim+input_dim, hidden_dim, 3, padding=1) + + def forward(self, h, x): + hx = torch.cat([h, x], dim=1) + + z = torch.sigmoid(self.convz(hx)) + r = torch.sigmoid(self.convr(hx)) + q = torch.tanh(self.convq(torch.cat([r*h, x], dim=1))) + + h = (1-z) * h + z * q + return h + +class SepConvGRU(nn.Module): + def __init__(self, hidden_dim=128, input_dim=192+128): + super(SepConvGRU, self).__init__() + self.convz1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + self.convr1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + self.convq1 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (1,5), padding=(0,2)) + + self.convz2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + self.convr2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + self.convq2 = nn.Conv2d(hidden_dim+input_dim, hidden_dim, (5,1), padding=(2,0)) + + + def forward(self, h, x): + # horizontal + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz1(hx)) + r = torch.sigmoid(self.convr1(hx)) + q = torch.tanh(self.convq1(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + + # vertical + hx = torch.cat([h, x], dim=1) + z = torch.sigmoid(self.convz2(hx)) + r = torch.sigmoid(self.convr2(hx)) + q = 
torch.tanh(self.convq2(torch.cat([r*h, x], dim=1))) + h = (1-z) * h + z * q + + return h + +class SmallMotionEncoder(nn.Module): + def __init__(self, args): + super(SmallMotionEncoder, self).__init__() + cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 + self.convc1 = nn.Conv2d(cor_planes, 96, 1, padding=0) + self.convf1 = nn.Conv2d(2, 64, 7, padding=3) + self.convf2 = nn.Conv2d(64, 32, 3, padding=1) + self.conv = nn.Conv2d(128, 80, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + +class BasicMotionEncoder(nn.Module): + def __init__(self, args): + super(BasicMotionEncoder, self).__init__() + cor_planes = args.corr_levels * (2*args.corr_radius + 1)**2 + self.convc1 = nn.Conv2d(cor_planes, 256, 1, padding=0) + self.convc2 = nn.Conv2d(256, 192, 3, padding=1) + self.convf1 = nn.Conv2d(2, 128, 7, padding=3) + self.convf2 = nn.Conv2d(128, 64, 3, padding=1) + self.conv = nn.Conv2d(64+192, 128-2, 3, padding=1) + + def forward(self, flow, corr): + cor = F.relu(self.convc1(corr)) + cor = F.relu(self.convc2(cor)) + flo = F.relu(self.convf1(flow)) + flo = F.relu(self.convf2(flo)) + + cor_flo = torch.cat([cor, flo], dim=1) + out = F.relu(self.conv(cor_flo)) + return torch.cat([out, flow], dim=1) + +class SmallUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=96): + super(SmallUpdateBlock, self).__init__() + self.encoder = SmallMotionEncoder(args) + self.gru = ConvGRU(hidden_dim=hidden_dim, input_dim=82+64) + self.flow_head = FlowHead(hidden_dim, hidden_dim=128) + + def forward(self, net, inp, corr, flow): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + return net, None, delta_flow + +class BasicUpdateBlock(nn.Module): + def __init__(self, args, hidden_dim=128, input_dim=128): + super(BasicUpdateBlock, self).__init__() + self.args = args + self.encoder = BasicMotionEncoder(args) + self.gru = SepConvGRU(hidden_dim=hidden_dim, input_dim=128+hidden_dim) + self.flow_head = FlowHead(hidden_dim, hidden_dim=256) + + self.mask = nn.Sequential( + nn.Conv2d(128, 256, 3, padding=1), + nn.ReLU(inplace=True), + nn.Conv2d(256, 64*9, 1, padding=0)) + + def forward(self, net, inp, corr, flow, upsample=True): + motion_features = self.encoder(flow, corr) + inp = torch.cat([inp, motion_features], dim=1) + + net = self.gru(net, inp) + delta_flow = self.flow_head(net) + + # scale mask to balence gradients + mask = .25 * self.mask(net) + return net, mask, delta_flow + + + diff --git a/propainter/RAFT/utils/__init__.py b/propainter/RAFT/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0437149bfee42718973728158641020ccc1906ad --- /dev/null +++ b/propainter/RAFT/utils/__init__.py @@ -0,0 +1,2 @@ +from .flow_viz import flow_to_image +from .frame_utils import writeFlow diff --git a/propainter/RAFT/utils/augmentor.py b/propainter/RAFT/utils/augmentor.py new file mode 100644 index 0000000000000000000000000000000000000000..e81c4f2b5c16c31c0ae236d744f299d430228a04 --- /dev/null +++ b/propainter/RAFT/utils/augmentor.py @@ -0,0 +1,246 @@ +import numpy as np +import random +import math +from PIL import Image + +import cv2 +cv2.setNumThreads(0) +cv2.ocl.setUseOpenCL(False) + +import torch +from torchvision.transforms import ColorJitter +import 
torch.nn.functional as F + + +class FlowAugmentor: + def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=True): + + # spatial augmentation params + self.crop_size = crop_size + self.min_scale = min_scale + self.max_scale = max_scale + self.spatial_aug_prob = 0.8 + self.stretch_prob = 0.8 + self.max_stretch = 0.2 + + # flip augmentation params + self.do_flip = do_flip + self.h_flip_prob = 0.5 + self.v_flip_prob = 0.1 + + # photometric augmentation params + self.photo_aug = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.5/3.14) + self.asymmetric_color_aug_prob = 0.2 + self.eraser_aug_prob = 0.5 + + def color_transform(self, img1, img2): + """ Photometric augmentation """ + + # asymmetric + if np.random.rand() < self.asymmetric_color_aug_prob: + img1 = np.array(self.photo_aug(Image.fromarray(img1)), dtype=np.uint8) + img2 = np.array(self.photo_aug(Image.fromarray(img2)), dtype=np.uint8) + + # symmetric + else: + image_stack = np.concatenate([img1, img2], axis=0) + image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) + img1, img2 = np.split(image_stack, 2, axis=0) + + return img1, img2 + + def eraser_transform(self, img1, img2, bounds=[50, 100]): + """ Occlusion augmentation """ + + ht, wd = img1.shape[:2] + if np.random.rand() < self.eraser_aug_prob: + mean_color = np.mean(img2.reshape(-1, 3), axis=0) + for _ in range(np.random.randint(1, 3)): + x0 = np.random.randint(0, wd) + y0 = np.random.randint(0, ht) + dx = np.random.randint(bounds[0], bounds[1]) + dy = np.random.randint(bounds[0], bounds[1]) + img2[y0:y0+dy, x0:x0+dx, :] = mean_color + + return img1, img2 + + def spatial_transform(self, img1, img2, flow): + # randomly sample scale + ht, wd = img1.shape[:2] + min_scale = np.maximum( + (self.crop_size[0] + 8) / float(ht), + (self.crop_size[1] + 8) / float(wd)) + + scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) + scale_x = scale + scale_y = scale + if np.random.rand() < self.stretch_prob: + scale_x *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) + scale_y *= 2 ** np.random.uniform(-self.max_stretch, self.max_stretch) + + scale_x = np.clip(scale_x, min_scale, None) + scale_y = np.clip(scale_y, min_scale, None) + + if np.random.rand() < self.spatial_aug_prob: + # rescale the images + img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow = cv2.resize(flow, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow = flow * [scale_x, scale_y] + + if self.do_flip: + if np.random.rand() < self.h_flip_prob: # h-flip + img1 = img1[:, ::-1] + img2 = img2[:, ::-1] + flow = flow[:, ::-1] * [-1.0, 1.0] + + if np.random.rand() < self.v_flip_prob: # v-flip + img1 = img1[::-1, :] + img2 = img2[::-1, :] + flow = flow[::-1, :] * [1.0, -1.0] + + y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0]) + x0 = np.random.randint(0, img1.shape[1] - self.crop_size[1]) + + img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + flow = flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + + return img1, img2, flow + + def __call__(self, img1, img2, flow): + img1, img2 = self.color_transform(img1, img2) + img1, img2 = self.eraser_transform(img1, img2) + img1, img2, flow = self.spatial_transform(img1, img2, flow) + + img1 = np.ascontiguousarray(img1) + img2 = np.ascontiguousarray(img2) + flow = 
np.ascontiguousarray(flow) + + return img1, img2, flow + +class SparseFlowAugmentor: + def __init__(self, crop_size, min_scale=-0.2, max_scale=0.5, do_flip=False): + # spatial augmentation params + self.crop_size = crop_size + self.min_scale = min_scale + self.max_scale = max_scale + self.spatial_aug_prob = 0.8 + self.stretch_prob = 0.8 + self.max_stretch = 0.2 + + # flip augmentation params + self.do_flip = do_flip + self.h_flip_prob = 0.5 + self.v_flip_prob = 0.1 + + # photometric augmentation params + self.photo_aug = ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.3/3.14) + self.asymmetric_color_aug_prob = 0.2 + self.eraser_aug_prob = 0.5 + + def color_transform(self, img1, img2): + image_stack = np.concatenate([img1, img2], axis=0) + image_stack = np.array(self.photo_aug(Image.fromarray(image_stack)), dtype=np.uint8) + img1, img2 = np.split(image_stack, 2, axis=0) + return img1, img2 + + def eraser_transform(self, img1, img2): + ht, wd = img1.shape[:2] + if np.random.rand() < self.eraser_aug_prob: + mean_color = np.mean(img2.reshape(-1, 3), axis=0) + for _ in range(np.random.randint(1, 3)): + x0 = np.random.randint(0, wd) + y0 = np.random.randint(0, ht) + dx = np.random.randint(50, 100) + dy = np.random.randint(50, 100) + img2[y0:y0+dy, x0:x0+dx, :] = mean_color + + return img1, img2 + + def resize_sparse_flow_map(self, flow, valid, fx=1.0, fy=1.0): + ht, wd = flow.shape[:2] + coords = np.meshgrid(np.arange(wd), np.arange(ht)) + coords = np.stack(coords, axis=-1) + + coords = coords.reshape(-1, 2).astype(np.float32) + flow = flow.reshape(-1, 2).astype(np.float32) + valid = valid.reshape(-1).astype(np.float32) + + coords0 = coords[valid>=1] + flow0 = flow[valid>=1] + + ht1 = int(round(ht * fy)) + wd1 = int(round(wd * fx)) + + coords1 = coords0 * [fx, fy] + flow1 = flow0 * [fx, fy] + + xx = np.round(coords1[:,0]).astype(np.int32) + yy = np.round(coords1[:,1]).astype(np.int32) + + v = (xx > 0) & (xx < wd1) & (yy > 0) & (yy < ht1) + xx = xx[v] + yy = yy[v] + flow1 = flow1[v] + + flow_img = np.zeros([ht1, wd1, 2], dtype=np.float32) + valid_img = np.zeros([ht1, wd1], dtype=np.int32) + + flow_img[yy, xx] = flow1 + valid_img[yy, xx] = 1 + + return flow_img, valid_img + + def spatial_transform(self, img1, img2, flow, valid): + # randomly sample scale + + ht, wd = img1.shape[:2] + min_scale = np.maximum( + (self.crop_size[0] + 1) / float(ht), + (self.crop_size[1] + 1) / float(wd)) + + scale = 2 ** np.random.uniform(self.min_scale, self.max_scale) + scale_x = np.clip(scale, min_scale, None) + scale_y = np.clip(scale, min_scale, None) + + if np.random.rand() < self.spatial_aug_prob: + # rescale the images + img1 = cv2.resize(img1, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + img2 = cv2.resize(img2, None, fx=scale_x, fy=scale_y, interpolation=cv2.INTER_LINEAR) + flow, valid = self.resize_sparse_flow_map(flow, valid, fx=scale_x, fy=scale_y) + + if self.do_flip: + if np.random.rand() < 0.5: # h-flip + img1 = img1[:, ::-1] + img2 = img2[:, ::-1] + flow = flow[:, ::-1] * [-1.0, 1.0] + valid = valid[:, ::-1] + + margin_y = 20 + margin_x = 50 + + y0 = np.random.randint(0, img1.shape[0] - self.crop_size[0] + margin_y) + x0 = np.random.randint(-margin_x, img1.shape[1] - self.crop_size[1] + margin_x) + + y0 = np.clip(y0, 0, img1.shape[0] - self.crop_size[0]) + x0 = np.clip(x0, 0, img1.shape[1] - self.crop_size[1]) + + img1 = img1[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + img2 = img2[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + flow = 
flow[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + valid = valid[y0:y0+self.crop_size[0], x0:x0+self.crop_size[1]] + return img1, img2, flow, valid + + + def __call__(self, img1, img2, flow, valid): + img1, img2 = self.color_transform(img1, img2) + img1, img2 = self.eraser_transform(img1, img2) + img1, img2, flow, valid = self.spatial_transform(img1, img2, flow, valid) + + img1 = np.ascontiguousarray(img1) + img2 = np.ascontiguousarray(img2) + flow = np.ascontiguousarray(flow) + valid = np.ascontiguousarray(valid) + + return img1, img2, flow, valid diff --git a/propainter/RAFT/utils/flow_viz.py b/propainter/RAFT/utils/flow_viz.py new file mode 100644 index 0000000000000000000000000000000000000000..dcee65e89b91b07ee0496aeb4c7e7436abf99641 --- /dev/null +++ b/propainter/RAFT/utils/flow_viz.py @@ -0,0 +1,132 @@ +# Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization + + +# MIT License +# +# Copyright (c) 2018 Tom Runia +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to conditions. +# +# Author: Tom Runia +# Date Created: 2018-08-03 + +import numpy as np + +def make_colorwheel(): + """ + Generates a color wheel for optical flow visualization as presented in: + Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) + URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf + + Code follows the original C++ source code of Daniel Scharstein. + Code follows the the Matlab source code of Deqing Sun. + + Returns: + np.ndarray: Color wheel + """ + + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + colorwheel = np.zeros((ncols, 3)) + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = np.floor(255*np.arange(0,RY)/RY) + col = col+RY + # YG + colorwheel[col:col+YG, 0] = 255 - np.floor(255*np.arange(0,YG)/YG) + colorwheel[col:col+YG, 1] = 255 + col = col+YG + # GC + colorwheel[col:col+GC, 1] = 255 + colorwheel[col:col+GC, 2] = np.floor(255*np.arange(0,GC)/GC) + col = col+GC + # CB + colorwheel[col:col+CB, 1] = 255 - np.floor(255*np.arange(CB)/CB) + colorwheel[col:col+CB, 2] = 255 + col = col+CB + # BM + colorwheel[col:col+BM, 2] = 255 + colorwheel[col:col+BM, 0] = np.floor(255*np.arange(0,BM)/BM) + col = col+BM + # MR + colorwheel[col:col+MR, 2] = 255 - np.floor(255*np.arange(MR)/MR) + colorwheel[col:col+MR, 0] = 255 + return colorwheel + + +def flow_uv_to_colors(u, v, convert_to_bgr=False): + """ + Applies the flow color wheel to (possibly clipped) flow components u and v. + + According to the C++ source code of Daniel Scharstein + According to the Matlab source code of Deqing Sun + + Args: + u (np.ndarray): Input horizontal flow of shape [H,W] + v (np.ndarray): Input vertical flow of shape [H,W] + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 
+ + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) + colorwheel = make_colorwheel() # shape [55x3] + ncols = colorwheel.shape[0] + rad = np.sqrt(np.square(u) + np.square(v)) + a = np.arctan2(-v, -u)/np.pi + fk = (a+1) / 2*(ncols-1) + k0 = np.floor(fk).astype(np.int32) + k1 = k0 + 1 + k1[k1 == ncols] = 0 + f = fk - k0 + for i in range(colorwheel.shape[1]): + tmp = colorwheel[:,i] + col0 = tmp[k0] / 255.0 + col1 = tmp[k1] / 255.0 + col = (1-f)*col0 + f*col1 + idx = (rad <= 1) + col[idx] = 1 - rad[idx] * (1-col[idx]) + col[~idx] = col[~idx] * 0.75 # out of range + # Note the 2-i => BGR instead of RGB + ch_idx = 2-i if convert_to_bgr else i + flow_image[:,:,ch_idx] = np.floor(255 * col) + return flow_image + + +def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False): + """ + Expects a two dimensional flow image of shape. + + Args: + flow_uv (np.ndarray): Flow UV image of shape [H,W,2] + clip_flow (float, optional): Clip maximum of flow values. Defaults to None. + convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. + + Returns: + np.ndarray: Flow visualization image of shape [H,W,3] + """ + assert flow_uv.ndim == 3, 'input flow must have three dimensions' + assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]' + if clip_flow is not None: + flow_uv = np.clip(flow_uv, 0, clip_flow) + u = flow_uv[:,:,0] + v = flow_uv[:,:,1] + rad = np.sqrt(np.square(u) + np.square(v)) + rad_max = np.max(rad) + epsilon = 1e-5 + u = u / (rad_max + epsilon) + v = v / (rad_max + epsilon) + return flow_uv_to_colors(u, v, convert_to_bgr) \ No newline at end of file diff --git a/propainter/RAFT/utils/flow_viz_pt.py b/propainter/RAFT/utils/flow_viz_pt.py new file mode 100644 index 0000000000000000000000000000000000000000..12e666a40fa49c11592e311b141aa2a522e567fd --- /dev/null +++ b/propainter/RAFT/utils/flow_viz_pt.py @@ -0,0 +1,118 @@ +# Flow visualization code adapted from https://github.com/tomrunia/OpticalFlow_Visualization +import torch +torch.pi = torch.acos(torch.zeros(1)).item() * 2 # which is 3.1415927410125732 + +@torch.no_grad() +def flow_to_image(flow: torch.Tensor) -> torch.Tensor: + + """ + Converts a flow to an RGB image. + + Args: + flow (Tensor): Flow of shape (N, 2, H, W) or (2, H, W) and dtype torch.float. + + Returns: + img (Tensor): Image Tensor of dtype uint8 where each color corresponds + to a given flow direction. Shape is (N, 3, H, W) or (3, H, W) depending on the input. + """ + + if flow.dtype != torch.float: + raise ValueError(f"Flow should be of dtype torch.float, got {flow.dtype}.") + + orig_shape = flow.shape + if flow.ndim == 3: + flow = flow[None] # Add batch dim + + if flow.ndim != 4 or flow.shape[1] != 2: + raise ValueError(f"Input flow should have shape (2, H, W) or (N, 2, H, W), got {orig_shape}.") + + max_norm = torch.sum(flow**2, dim=1).sqrt().max() + epsilon = torch.finfo((flow).dtype).eps + normalized_flow = flow / (max_norm + epsilon) + img = _normalized_flow_to_image(normalized_flow) + + if len(orig_shape) == 3: + img = img[0] # Remove batch dim + return img + +@torch.no_grad() +def _normalized_flow_to_image(normalized_flow: torch.Tensor) -> torch.Tensor: + + """ + Converts a batch of normalized flow to an RGB image. + + Args: + normalized_flow (torch.Tensor): Normalized flow tensor of shape (N, 2, H, W) + Returns: + img (Tensor(N, 3, H, W)): Flow visualization image of dtype uint8. 
+ """ + + N, _, H, W = normalized_flow.shape + device = normalized_flow.device + flow_image = torch.zeros((N, 3, H, W), dtype=torch.uint8, device=device) + colorwheel = _make_colorwheel().to(device) # shape [55x3] + num_cols = colorwheel.shape[0] + norm = torch.sum(normalized_flow**2, dim=1).sqrt() + a = torch.atan2(-normalized_flow[:, 1, :, :], -normalized_flow[:, 0, :, :]) / torch.pi + fk = (a + 1) / 2 * (num_cols - 1) + k0 = torch.floor(fk).to(torch.long) + k1 = k0 + 1 + k1[k1 == num_cols] = 0 + f = fk - k0 + + for c in range(colorwheel.shape[1]): + tmp = colorwheel[:, c] + col0 = tmp[k0] / 255.0 + col1 = tmp[k1] / 255.0 + col = (1 - f) * col0 + f * col1 + col = 1 - norm * (1 - col) + flow_image[:, c, :, :] = torch.floor(255. * col) + return flow_image + + +@torch.no_grad() +def _make_colorwheel() -> torch.Tensor: + """ + Generates a color wheel for optical flow visualization as presented in: + Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) + URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf. + + Returns: + colorwheel (Tensor[55, 3]): Colorwheel Tensor. + """ + + RY = 15 + YG = 6 + GC = 4 + CB = 11 + BM = 13 + MR = 6 + + ncols = RY + YG + GC + CB + BM + MR + colorwheel = torch.zeros((ncols, 3)) + col = 0 + + # RY + colorwheel[0:RY, 0] = 255 + colorwheel[0:RY, 1] = torch.floor(255. * torch.arange(0., RY) / RY) + col = col + RY + # YG + colorwheel[col : col + YG, 0] = 255 - torch.floor(255. * torch.arange(0., YG) / YG) + colorwheel[col : col + YG, 1] = 255 + col = col + YG + # GC + colorwheel[col : col + GC, 1] = 255 + colorwheel[col : col + GC, 2] = torch.floor(255. * torch.arange(0., GC) / GC) + col = col + GC + # CB + colorwheel[col : col + CB, 1] = 255 - torch.floor(255. * torch.arange(CB) / CB) + colorwheel[col : col + CB, 2] = 255 + col = col + CB + # BM + colorwheel[col : col + BM, 2] = 255 + colorwheel[col : col + BM, 0] = torch.floor(255. * torch.arange(0., BM) / BM) + col = col + BM + # MR + colorwheel[col : col + MR, 2] = 255 - torch.floor(255. * torch.arange(MR) / MR) + colorwheel[col : col + MR, 0] = 255 + return colorwheel diff --git a/propainter/RAFT/utils/frame_utils.py b/propainter/RAFT/utils/frame_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6c491135efaffc25bd61ec3ecde99d236f5deb12 --- /dev/null +++ b/propainter/RAFT/utils/frame_utils.py @@ -0,0 +1,137 @@ +import numpy as np +from PIL import Image +from os.path import * +import re + +import cv2 +cv2.setNumThreads(0) +cv2.ocl.setUseOpenCL(False) + +TAG_CHAR = np.array([202021.25], np.float32) + +def readFlow(fn): + """ Read .flo file in Middlebury format""" + # Code adapted from: + # http://stackoverflow.com/questions/28013200/reading-middlebury-flow-files-with-python-bytes-array-numpy + + # WARNING: this will work on little-endian architectures (eg Intel x86) only! + # print 'fn = %s'%(fn) + with open(fn, 'rb') as f: + magic = np.fromfile(f, np.float32, count=1) + if 202021.25 != magic: + print('Magic number incorrect. 
Invalid .flo file') + return None + else: + w = np.fromfile(f, np.int32, count=1) + h = np.fromfile(f, np.int32, count=1) + # print 'Reading %d x %d flo file\n' % (w, h) + data = np.fromfile(f, np.float32, count=2*int(w)*int(h)) + # Reshape data into 3D array (columns, rows, bands) + # The reshape here is for visualization, the original code is (w,h,2) + return np.resize(data, (int(h), int(w), 2)) + +def readPFM(file): + file = open(file, 'rb') + + color = None + width = None + height = None + scale = None + endian = None + + header = file.readline().rstrip() + if header == b'PF': + color = True + elif header == b'Pf': + color = False + else: + raise Exception('Not a PFM file.') + + dim_match = re.match(rb'^(\d+)\s(\d+)\s$', file.readline()) + if dim_match: + width, height = map(int, dim_match.groups()) + else: + raise Exception('Malformed PFM header.') + + scale = float(file.readline().rstrip()) + if scale < 0: # little-endian + endian = '<' + scale = -scale + else: + endian = '>' # big-endian + + data = np.fromfile(file, endian + 'f') + shape = (height, width, 3) if color else (height, width) + + data = np.reshape(data, shape) + data = np.flipud(data) + return data + +def writeFlow(filename,uv,v=None): + """ Write optical flow to file. + + If v is None, uv is assumed to contain both u and v channels, + stacked in depth. + Original code by Deqing Sun, adapted from Daniel Scharstein. + """ + nBands = 2 + + if v is None: + assert(uv.ndim == 3) + assert(uv.shape[2] == 2) + u = uv[:,:,0] + v = uv[:,:,1] + else: + u = uv + + assert(u.shape == v.shape) + height,width = u.shape + f = open(filename,'wb') + # write the header + f.write(TAG_CHAR) + np.array(width).astype(np.int32).tofile(f) + np.array(height).astype(np.int32).tofile(f) + # arrange into matrix form + tmp = np.zeros((height, width*nBands)) + tmp[:,np.arange(width)*2] = u + tmp[:,np.arange(width)*2 + 1] = v + tmp.astype(np.float32).tofile(f) + f.close() + + +def readFlowKITTI(filename): + flow = cv2.imread(filename, cv2.IMREAD_ANYDEPTH|cv2.IMREAD_COLOR) + flow = flow[:,:,::-1].astype(np.float32) + flow, valid = flow[:, :, :2], flow[:, :, 2] + flow = (flow - 2**15) / 64.0 + return flow, valid + +def readDispKITTI(filename): + disp = cv2.imread(filename, cv2.IMREAD_ANYDEPTH) / 256.0 + valid = disp > 0.0 + flow = np.stack([-disp, np.zeros_like(disp)], -1) + return flow, valid + + +def writeFlowKITTI(filename, uv): + uv = 64.0 * uv + 2**15 + valid = np.ones([uv.shape[0], uv.shape[1], 1]) + uv = np.concatenate([uv, valid], axis=-1).astype(np.uint16) + cv2.imwrite(filename, uv[..., ::-1]) + + +def read_gen(file_name, pil=False): + ext = splitext(file_name)[-1] + if ext == '.png' or ext == '.jpeg' or ext == '.ppm' or ext == '.jpg': + return Image.open(file_name) + elif ext == '.bin' or ext == '.raw': + return np.load(file_name) + elif ext == '.flo': + return readFlow(file_name).astype(np.float32) + elif ext == '.pfm': + flow = readPFM(file_name).astype(np.float32) + if len(flow.shape) == 2: + return flow + else: + return flow[:, :, :-1] + return [] \ No newline at end of file diff --git a/propainter/RAFT/utils/utils.py b/propainter/RAFT/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5f32d281c1c46353a0a2bf36b0550adb74125c65 --- /dev/null +++ b/propainter/RAFT/utils/utils.py @@ -0,0 +1,82 @@ +import torch +import torch.nn.functional as F +import numpy as np +from scipy import interpolate + + +class InputPadder: + """ Pads images such that dimensions are divisible by 8 """ + def __init__(self, dims, 
mode='sintel'): + self.ht, self.wd = dims[-2:] + pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8 + pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8 + if mode == 'sintel': + self._pad = [pad_wd//2, pad_wd - pad_wd//2, pad_ht//2, pad_ht - pad_ht//2] + else: + self._pad = [pad_wd//2, pad_wd - pad_wd//2, 0, pad_ht] + + def pad(self, *inputs): + return [F.pad(x, self._pad, mode='replicate') for x in inputs] + + def unpad(self,x): + ht, wd = x.shape[-2:] + c = [self._pad[2], ht-self._pad[3], self._pad[0], wd-self._pad[1]] + return x[..., c[0]:c[1], c[2]:c[3]] + +def forward_interpolate(flow): + flow = flow.detach().cpu().numpy() + dx, dy = flow[0], flow[1] + + ht, wd = dx.shape + x0, y0 = np.meshgrid(np.arange(wd), np.arange(ht)) + + x1 = x0 + dx + y1 = y0 + dy + + x1 = x1.reshape(-1) + y1 = y1.reshape(-1) + dx = dx.reshape(-1) + dy = dy.reshape(-1) + + valid = (x1 > 0) & (x1 < wd) & (y1 > 0) & (y1 < ht) + x1 = x1[valid] + y1 = y1[valid] + dx = dx[valid] + dy = dy[valid] + + flow_x = interpolate.griddata( + (x1, y1), dx, (x0, y0), method='nearest', fill_value=0) + + flow_y = interpolate.griddata( + (x1, y1), dy, (x0, y0), method='nearest', fill_value=0) + + flow = np.stack([flow_x, flow_y], axis=0) + return torch.from_numpy(flow).float() + + +def bilinear_sampler(img, coords, mode='bilinear', mask=False): + """ Wrapper for grid_sample, uses pixel coordinates """ + H, W = img.shape[-2:] + xgrid, ygrid = coords.split([1,1], dim=-1) + xgrid = 2*xgrid/(W-1) - 1 + ygrid = 2*ygrid/(H-1) - 1 + + grid = torch.cat([xgrid, ygrid], dim=-1) + img = F.grid_sample(img, grid, align_corners=True) + + if mask: + mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1) + return img, mask.float() + + return img + + +def coords_grid(batch, ht, wd): + coords = torch.meshgrid(torch.arange(ht), torch.arange(wd)) + coords = torch.stack(coords[::-1], dim=0).float() + return coords[None].repeat(batch, 1, 1, 1) + + +def upflow8(flow, mode='bilinear'): + new_size = (8 * flow.shape[2], 8 * flow.shape[3]) + return 8 * F.interpolate(flow, size=new_size, mode=mode, align_corners=True) diff --git a/propainter/core/dataset.py b/propainter/core/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..27b135bb7716f0e89d9a3ec9fd4411dfe3eb94eb --- /dev/null +++ b/propainter/core/dataset.py @@ -0,0 +1,232 @@ +import os +import json +import random + +import cv2 +from PIL import Image +import numpy as np + +import torch +import torchvision.transforms as transforms + +from utils.file_client import FileClient +from utils.img_util import imfrombytes +from utils.flow_util import resize_flow, flowread +from core.utils import (create_random_shape_with_random_motion, Stack, + ToTorchFormatTensor, GroupRandomHorizontalFlip,GroupRandomHorizontalFlowFlip) + + +class TrainDataset(torch.utils.data.Dataset): + def __init__(self, args: dict): + self.args = args + self.video_root = args['video_root'] + self.flow_root = args['flow_root'] + self.num_local_frames = args['num_local_frames'] + self.num_ref_frames = args['num_ref_frames'] + self.size = self.w, self.h = (args['w'], args['h']) + + self.load_flow = args['load_flow'] + if self.load_flow: + assert os.path.exists(self.flow_root) + + json_path = os.path.join('./datasets', args['name'], 'train.json') + + with open(json_path, 'r') as f: + self.video_train_dict = json.load(f) + self.video_names = sorted(list(self.video_train_dict.keys())) + + # self.video_names = sorted(os.listdir(self.video_root)) + self.video_dict = {} + self.frame_dict = {} + + for v in 
self.video_names: + frame_list = sorted(os.listdir(os.path.join(self.video_root, v))) + v_len = len(frame_list) + if v_len > self.num_local_frames + self.num_ref_frames: + self.video_dict[v] = v_len + self.frame_dict[v] = frame_list + + + self.video_names = list(self.video_dict.keys()) # update names + + self._to_tensors = transforms.Compose([ + Stack(), + ToTorchFormatTensor(), + ]) + self.file_client = FileClient('disk') + + def __len__(self): + return len(self.video_names) + + def _sample_index(self, length, sample_length, num_ref_frame=3): + complete_idx_set = list(range(length)) + pivot = random.randint(0, length - sample_length) + local_idx = complete_idx_set[pivot:pivot + sample_length] + remain_idx = list(set(complete_idx_set) - set(local_idx)) + ref_index = sorted(random.sample(remain_idx, num_ref_frame)) + + return local_idx + ref_index + + def __getitem__(self, index): + video_name = self.video_names[index] + # create masks + all_masks = create_random_shape_with_random_motion( + self.video_dict[video_name], imageHeight=self.h, imageWidth=self.w) + + # create sample index + selected_index = self._sample_index(self.video_dict[video_name], + self.num_local_frames, + self.num_ref_frames) + + # read video frames + frames = [] + masks = [] + flows_f, flows_b = [], [] + for idx in selected_index: + frame_list = self.frame_dict[video_name] + img_path = os.path.join(self.video_root, video_name, frame_list[idx]) + img_bytes = self.file_client.get(img_path, 'img') + img = imfrombytes(img_bytes, float32=False) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = cv2.resize(img, self.size, interpolation=cv2.INTER_LINEAR) + img = Image.fromarray(img) + + frames.append(img) + masks.append(all_masks[idx]) + + if len(frames) <= self.num_local_frames-1 and self.load_flow: + current_n = frame_list[idx][:-4] + next_n = frame_list[idx+1][:-4] + flow_f_path = os.path.join(self.flow_root, video_name, f'{current_n}_{next_n}_f.flo') + flow_b_path = os.path.join(self.flow_root, video_name, f'{next_n}_{current_n}_b.flo') + flow_f = flowread(flow_f_path, quantize=False) + flow_b = flowread(flow_b_path, quantize=False) + flow_f = resize_flow(flow_f, self.h, self.w) + flow_b = resize_flow(flow_b, self.h, self.w) + flows_f.append(flow_f) + flows_b.append(flow_b) + + if len(frames) == self.num_local_frames: # random reverse + if random.random() < 0.5: + frames.reverse() + masks.reverse() + if self.load_flow: + flows_f.reverse() + flows_b.reverse() + flows_ = flows_f + flows_f = flows_b + flows_b = flows_ + + if self.load_flow: + frames, flows_f, flows_b = GroupRandomHorizontalFlowFlip()(frames, flows_f, flows_b) + else: + frames = GroupRandomHorizontalFlip()(frames) + + # normalizate, to tensors + frame_tensors = self._to_tensors(frames) * 2.0 - 1.0 + mask_tensors = self._to_tensors(masks) + if self.load_flow: + flows_f = np.stack(flows_f, axis=-1) # H W 2 T-1 + flows_b = np.stack(flows_b, axis=-1) + flows_f = torch.from_numpy(flows_f).permute(3, 2, 0, 1).contiguous().float() + flows_b = torch.from_numpy(flows_b).permute(3, 2, 0, 1).contiguous().float() + + # img [-1,1] mask [0,1] + if self.load_flow: + return frame_tensors, mask_tensors, flows_f, flows_b, video_name + else: + return frame_tensors, mask_tensors, 'None', 'None', video_name + + +class TestDataset(torch.utils.data.Dataset): + def __init__(self, args): + self.args = args + self.size = self.w, self.h = args['size'] + + self.video_root = args['video_root'] + self.mask_root = args['mask_root'] + self.flow_root = args['flow_root'] + + self.load_flow = 
args['load_flow'] + if self.load_flow: + assert os.path.exists(self.flow_root) + self.video_names = sorted(os.listdir(self.mask_root)) + + self.video_dict = {} + self.frame_dict = {} + + for v in self.video_names: + frame_list = sorted(os.listdir(os.path.join(self.video_root, v))) + v_len = len(frame_list) + self.video_dict[v] = v_len + self.frame_dict[v] = frame_list + + self._to_tensors = transforms.Compose([ + Stack(), + ToTorchFormatTensor(), + ]) + self.file_client = FileClient('disk') + + def __len__(self): + return len(self.video_names) + + def __getitem__(self, index): + video_name = self.video_names[index] + selected_index = list(range(self.video_dict[video_name])) + + # read video frames + frames = [] + masks = [] + flows_f, flows_b = [], [] + for idx in selected_index: + frame_list = self.frame_dict[video_name] + frame_path = os.path.join(self.video_root, video_name, frame_list[idx]) + + img_bytes = self.file_client.get(frame_path, 'input') + img = imfrombytes(img_bytes, float32=False) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = cv2.resize(img, self.size, interpolation=cv2.INTER_LINEAR) + img = Image.fromarray(img) + + frames.append(img) + + mask_path = os.path.join(self.mask_root, video_name, str(idx).zfill(5) + '.png') + mask = Image.open(mask_path).resize(self.size, Image.NEAREST).convert('L') + + # origin: 0 indicates missing. now: 1 indicates missing + mask = np.asarray(mask) + m = np.array(mask > 0).astype(np.uint8) + + m = cv2.dilate(m, + cv2.getStructuringElement(cv2.MORPH_CROSS, (3, 3)), + iterations=4) + mask = Image.fromarray(m * 255) + masks.append(mask) + + if len(frames) <= len(selected_index)-1 and self.load_flow: + current_n = frame_list[idx][:-4] + next_n = frame_list[idx+1][:-4] + flow_f_path = os.path.join(self.flow_root, video_name, f'{current_n}_{next_n}_f.flo') + flow_b_path = os.path.join(self.flow_root, video_name, f'{next_n}_{current_n}_b.flo') + flow_f = flowread(flow_f_path, quantize=False) + flow_b = flowread(flow_b_path, quantize=False) + flow_f = resize_flow(flow_f, self.h, self.w) + flow_b = resize_flow(flow_b, self.h, self.w) + flows_f.append(flow_f) + flows_b.append(flow_b) + + # normalizate, to tensors + frames_PIL = [np.array(f).astype(np.uint8) for f in frames] + frame_tensors = self._to_tensors(frames) * 2.0 - 1.0 + mask_tensors = self._to_tensors(masks) + + if self.load_flow: + flows_f = np.stack(flows_f, axis=-1) # H W 2 T-1 + flows_b = np.stack(flows_b, axis=-1) + flows_f = torch.from_numpy(flows_f).permute(3, 2, 0, 1).contiguous().float() + flows_b = torch.from_numpy(flows_b).permute(3, 2, 0, 1).contiguous().float() + + if self.load_flow: + return frame_tensors, mask_tensors, flows_f, flows_b, video_name, frames_PIL + else: + return frame_tensors, mask_tensors, 'None', 'None', video_name \ No newline at end of file diff --git a/propainter/core/dist.py b/propainter/core/dist.py new file mode 100644 index 0000000000000000000000000000000000000000..4e4e9e670a3b853fac345618d3557d648d813902 --- /dev/null +++ b/propainter/core/dist.py @@ -0,0 +1,47 @@ +import os +import torch + + +def get_world_size(): + """Find OMPI world size without calling mpi functions + :rtype: int + """ + if os.environ.get('PMI_SIZE') is not None: + return int(os.environ.get('PMI_SIZE') or 1) + elif os.environ.get('OMPI_COMM_WORLD_SIZE') is not None: + return int(os.environ.get('OMPI_COMM_WORLD_SIZE') or 1) + else: + return torch.cuda.device_count() + + +def get_global_rank(): + """Find OMPI world rank without calling mpi functions + :rtype: int + """ + if 
os.environ.get('PMI_RANK') is not None: + return int(os.environ.get('PMI_RANK') or 0) + elif os.environ.get('OMPI_COMM_WORLD_RANK') is not None: + return int(os.environ.get('OMPI_COMM_WORLD_RANK') or 0) + else: + return 0 + + +def get_local_rank(): + """Find OMPI local rank without calling mpi functions + :rtype: int + """ + if os.environ.get('MPI_LOCALRANKID') is not None: + return int(os.environ.get('MPI_LOCALRANKID') or 0) + elif os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') is not None: + return int(os.environ.get('OMPI_COMM_WORLD_LOCAL_RANK') or 0) + else: + return 0 + + +def get_master_ip(): + if os.environ.get('AZ_BATCH_MASTER_NODE') is not None: + return os.environ.get('AZ_BATCH_MASTER_NODE').split(':')[0] + elif os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') is not None: + return os.environ.get('AZ_BATCHAI_MPI_MASTER_NODE') + else: + return "127.0.0.1" diff --git a/propainter/core/loss.py b/propainter/core/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..b1d94d0ce9433b66ce2dce7adb24acb16051e8da --- /dev/null +++ b/propainter/core/loss.py @@ -0,0 +1,180 @@ +import torch +import torch.nn as nn +import lpips +from model.vgg_arch import VGGFeatureExtractor + +class PerceptualLoss(nn.Module): + """Perceptual loss with commonly used style loss. + + Args: + layer_weights (dict): The weight for each layer of vgg feature. + Here is an example: {'conv5_4': 1.}, which means the conv5_4 + feature layer (before relu5_4) will be extracted with weight + 1.0 in calculting losses. + vgg_type (str): The type of vgg network used as feature extractor. + Default: 'vgg19'. + use_input_norm (bool): If True, normalize the input image in vgg. + Default: True. + range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. + Default: False. + perceptual_weight (float): If `perceptual_weight > 0`, the perceptual + loss will be calculated and the loss will multiplied by the + weight. Default: 1.0. + style_weight (float): If `style_weight > 0`, the style loss will be + calculated and the loss will multiplied by the weight. + Default: 0. + criterion (str): Criterion used for perceptual loss. Default: 'l1'. + """ + + def __init__(self, + layer_weights, + vgg_type='vgg19', + use_input_norm=True, + range_norm=False, + perceptual_weight=1.0, + style_weight=0., + criterion='l1'): + super(PerceptualLoss, self).__init__() + self.perceptual_weight = perceptual_weight + self.style_weight = style_weight + self.layer_weights = layer_weights + self.vgg = VGGFeatureExtractor( + layer_name_list=list(layer_weights.keys()), + vgg_type=vgg_type, + use_input_norm=use_input_norm, + range_norm=range_norm) + + self.criterion_type = criterion + if self.criterion_type == 'l1': + self.criterion = torch.nn.L1Loss() + elif self.criterion_type == 'l2': + self.criterion = torch.nn.L2loss() + elif self.criterion_type == 'mse': + self.criterion = torch.nn.MSELoss(reduction='mean') + elif self.criterion_type == 'fro': + self.criterion = None + else: + raise NotImplementedError(f'{criterion} criterion has not been supported.') + + def forward(self, x, gt): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + gt (Tensor): Ground-truth tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. 
+ """ + # extract vgg features + x_features = self.vgg(x) + gt_features = self.vgg(gt.detach()) + + # calculate perceptual loss + if self.perceptual_weight > 0: + percep_loss = 0 + for k in x_features.keys(): + if self.criterion_type == 'fro': + percep_loss += torch.norm(x_features[k] - gt_features[k], p='fro') * self.layer_weights[k] + else: + percep_loss += self.criterion(x_features[k], gt_features[k]) * self.layer_weights[k] + percep_loss *= self.perceptual_weight + else: + percep_loss = None + + # calculate style loss + if self.style_weight > 0: + style_loss = 0 + for k in x_features.keys(): + if self.criterion_type == 'fro': + style_loss += torch.norm( + self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k] + else: + style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat( + gt_features[k])) * self.layer_weights[k] + style_loss *= self.style_weight + else: + style_loss = None + + return percep_loss, style_loss + + def _gram_mat(self, x): + """Calculate Gram matrix. + + Args: + x (torch.Tensor): Tensor with shape of (n, c, h, w). + + Returns: + torch.Tensor: Gram matrix. + """ + n, c, h, w = x.size() + features = x.view(n, c, w * h) + features_t = features.transpose(1, 2) + gram = features.bmm(features_t) / (c * h * w) + return gram + +class LPIPSLoss(nn.Module): + def __init__(self, + loss_weight=1.0, + use_input_norm=True, + range_norm=False,): + super(LPIPSLoss, self).__init__() + self.perceptual = lpips.LPIPS(net="vgg", spatial=False).eval() + self.loss_weight = loss_weight + self.use_input_norm = use_input_norm + self.range_norm = range_norm + + if self.use_input_norm: + # the mean is for image with range [0, 1] + self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + # the std is for image with range [0, 1] + self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) + + def forward(self, pred, target): + if self.range_norm: + pred = (pred + 1) / 2 + target = (target + 1) / 2 + if self.use_input_norm: + pred = (pred - self.mean) / self.std + target = (target - self.mean) / self.std + lpips_loss = self.perceptual(target.contiguous(), pred.contiguous()) + return self.loss_weight * lpips_loss.mean(), None + + +class AdversarialLoss(nn.Module): + r""" + Adversarial loss + https://arxiv.org/abs/1711.10337 + """ + def __init__(self, + type='nsgan', + target_real_label=1.0, + target_fake_label=0.0): + r""" + type = nsgan | lsgan | hinge + """ + super(AdversarialLoss, self).__init__() + self.type = type + self.register_buffer('real_label', torch.tensor(target_real_label)) + self.register_buffer('fake_label', torch.tensor(target_fake_label)) + + if type == 'nsgan': + self.criterion = nn.BCELoss() + elif type == 'lsgan': + self.criterion = nn.MSELoss() + elif type == 'hinge': + self.criterion = nn.ReLU() + + def __call__(self, outputs, is_real, is_disc=None): + if self.type == 'hinge': + if is_disc: + if is_real: + outputs = -outputs + return self.criterion(1 + outputs).mean() + else: + return (-outputs).mean() + else: + labels = (self.real_label + if is_real else self.fake_label).expand_as(outputs) + loss = self.criterion(outputs, labels) + return loss diff --git a/propainter/core/lr_scheduler.py b/propainter/core/lr_scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..1bd1341cdcc64aa1c2a416b837551590ded4a43d --- /dev/null +++ b/propainter/core/lr_scheduler.py @@ -0,0 +1,112 @@ +""" + LR scheduler from BasicSR https://github.com/xinntao/BasicSR 
+""" +import math +from collections import Counter +from torch.optim.lr_scheduler import _LRScheduler + + +class MultiStepRestartLR(_LRScheduler): + """ MultiStep with restarts learning rate scheme. + Args: + optimizer (torch.nn.optimizer): Torch optimizer. + milestones (list): Iterations that will decrease learning rate. + gamma (float): Decrease ratio. Default: 0.1. + restarts (list): Restart iterations. Default: [0]. + restart_weights (list): Restart weights at each restart iteration. + Default: [1]. + last_epoch (int): Used in _LRScheduler. Default: -1. + """ + def __init__(self, + optimizer, + milestones, + gamma=0.1, + restarts=(0, ), + restart_weights=(1, ), + last_epoch=-1): + self.milestones = Counter(milestones) + self.gamma = gamma + self.restarts = restarts + self.restart_weights = restart_weights + assert len(self.restarts) == len( + self.restart_weights), 'restarts and their weights do not match.' + super(MultiStepRestartLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + if self.last_epoch in self.restarts: + weight = self.restart_weights[self.restarts.index(self.last_epoch)] + return [ + group['initial_lr'] * weight + for group in self.optimizer.param_groups + ] + if self.last_epoch not in self.milestones: + return [group['lr'] for group in self.optimizer.param_groups] + return [ + group['lr'] * self.gamma**self.milestones[self.last_epoch] + for group in self.optimizer.param_groups + ] + + +def get_position_from_periods(iteration, cumulative_period): + """Get the position from a period list. + It will return the index of the right-closest number in the period list. + For example, the cumulative_period = [100, 200, 300, 400], + if iteration == 50, return 0; + if iteration == 210, return 2; + if iteration == 300, return 2. + Args: + iteration (int): Current iteration. + cumulative_period (list[int]): Cumulative period list. + Returns: + int: The position of the right-closest number in the period list. + """ + for i, period in enumerate(cumulative_period): + if iteration <= period: + return i + + +class CosineAnnealingRestartLR(_LRScheduler): + """ Cosine annealing with restarts learning rate scheme. + An example of config: + periods = [10, 10, 10, 10] + restart_weights = [1, 0.5, 0.5, 0.5] + eta_min=1e-7 + It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the + scheduler will restart with the weights in restart_weights. + Args: + optimizer (torch.nn.optimizer): Torch optimizer. + periods (list): Period for each cosine anneling cycle. + restart_weights (list): Restart weights at each restart iteration. + Default: [1]. + eta_min (float): The mimimum lr. Default: 0. + last_epoch (int): Used in _LRScheduler. Default: -1. + """ + def __init__(self, + optimizer, + periods, + restart_weights=(1, ), + eta_min=1e-7, + last_epoch=-1): + self.periods = periods + self.restart_weights = restart_weights + self.eta_min = eta_min + assert (len(self.periods) == len(self.restart_weights) + ), 'periods and restart_weights should have the same length.' 
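+        # Running sum of the cycle lengths; for the docstring example (periods = [10, 10, 10, 10])
+        # this is [10, 20, 30, 40]. get_lr() passes it to get_position_from_periods to locate the
+        # current cosine cycle and its most recent restart point.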
+ self.cumulative_period = [ + sum(self.periods[0:i + 1]) for i in range(0, len(self.periods)) + ] + super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch) + + def get_lr(self): + idx = get_position_from_periods(self.last_epoch, + self.cumulative_period) + current_weight = self.restart_weights[idx] + nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] + current_period = self.periods[idx] + + return [ + self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) * + (1 + math.cos(math.pi * ( + (self.last_epoch - nearest_restart) / current_period))) + for base_lr in self.base_lrs + ] diff --git a/propainter/core/metrics.py b/propainter/core/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..d5695c77249936ee1443c5da6ff90b14ee439a00 --- /dev/null +++ b/propainter/core/metrics.py @@ -0,0 +1,571 @@ +import numpy as np +# from skimage import measure +from skimage.metrics import structural_similarity as compare_ssim +from scipy import linalg + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from propainter.core.utils import to_tensors + + +def calculate_epe(flow1, flow2): + """Calculate End point errors.""" + + epe = torch.sum((flow1 - flow2)**2, dim=1).sqrt() + epe = epe.view(-1) + return epe.mean().item() + + +def calculate_psnr(img1, img2): + """Calculate PSNR (Peak Signal-to-Noise Ratio). + Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio + Args: + img1 (ndarray): Images with range [0, 255]. + img2 (ndarray): Images with range [0, 255]. + Returns: + float: psnr result. + """ + + assert img1.shape == img2.shape, \ + (f'Image shapes are differnet: {img1.shape}, {img2.shape}.') + + mse = np.mean((img1 - img2)**2) + if mse == 0: + return float('inf') + return 20. * np.log10(255. / np.sqrt(mse)) + + +def calc_psnr_and_ssim(img1, img2): + """Calculate PSNR and SSIM for images. + img1: ndarray, range [0, 255] + img2: ndarray, range [0, 255] + """ + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + + psnr = calculate_psnr(img1, img2) + ssim = compare_ssim(img1, + img2, + data_range=255, + multichannel=True, + win_size=65, + channel_axis=2) + + return psnr, ssim + + +########################### +# I3D models +########################### + + +def init_i3d_model(i3d_model_path): + print(f"[Loading I3D model from {i3d_model_path} for FID score ..]") + i3d_model = InceptionI3d(400, in_channels=3, final_endpoint='Logits') + i3d_model.load_state_dict(torch.load(i3d_model_path)) + i3d_model.to(torch.device('cuda:0')) + return i3d_model + + +def calculate_i3d_activations(video1, video2, i3d_model, device): + """Calculate VFID metric. 
+ video1: list[PIL.Image] + video2: list[PIL.Image] + """ + video1 = to_tensors()(video1).unsqueeze(0).to(device) + video2 = to_tensors()(video2).unsqueeze(0).to(device) + video1_activations = get_i3d_activations( + video1, i3d_model).cpu().numpy().flatten() + video2_activations = get_i3d_activations( + video2, i3d_model).cpu().numpy().flatten() + + return video1_activations, video2_activations + + +def calculate_vfid(real_activations, fake_activations): + """ + Given two distribution of features, compute the FID score between them + Params: + real_activations: list[ndarray] + fake_activations: list[ndarray] + """ + m1 = np.mean(real_activations, axis=0) + m2 = np.mean(fake_activations, axis=0) + s1 = np.cov(real_activations, rowvar=False) + s2 = np.cov(fake_activations, rowvar=False) + return calculate_frechet_distance(m1, s1, m2, s2) + + +def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + Stable version by Dougal J. Sutherland. + Params: + -- mu1 : Numpy array containing the activations of a layer of the + inception net (like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations, precalculated on an + representive data set. + -- sigma1: The covariance matrix over activations for generated samples. + -- sigma2: The covariance matrix over activations, precalculated on an + representive data set. + Returns: + -- : The Frechet Distance. + """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert mu1.shape == mu2.shape, \ + 'Training and test mean vectors have different lengths' + assert sigma1.shape == sigma2.shape, \ + 'Training and test covariances have different dimensions' + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ('fid calculation produces singular product; ' + 'adding %s to diagonal of cov estimates') % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError('Imaginary component {}'.format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return (diff.dot(diff) + np.trace(sigma1) + # NOQA + np.trace(sigma2) - 2 * tr_covmean) + + +def get_i3d_activations(batched_video, + i3d_model, + target_endpoint='Logits', + flatten=True, + grad_enabled=False): + """ + Get features from i3d model and flatten them to 1d feature, + valid target endpoints are defined in InceptionI3d.VALID_ENDPOINTS + VALID_ENDPOINTS = ( + 'Conv3d_1a_7x7', + 'MaxPool3d_2a_3x3', + 'Conv3d_2b_1x1', + 'Conv3d_2c_3x3', + 'MaxPool3d_3a_3x3', + 'Mixed_3b', + 'Mixed_3c', + 'MaxPool3d_4a_3x3', + 'Mixed_4b', + 'Mixed_4c', + 'Mixed_4d', + 'Mixed_4e', + 'Mixed_4f', + 'MaxPool3d_5a_2x2', + 'Mixed_5b', + 'Mixed_5c', + 'Logits', + 'Predictions', + ) + """ + with torch.set_grad_enabled(grad_enabled): + feat = i3d_model.extract_features(batched_video.transpose(1, 2), + target_endpoint) + if flatten: + feat = feat.view(feat.size(0), -1) + + return feat + + +# 
This code is from https://github.com/piergiaj/pytorch-i3d/blob/master/pytorch_i3d.py +# I only fix flake8 errors and do some cleaning here + + +class MaxPool3dSamePadding(nn.MaxPool3d): + def compute_pad(self, dim, s): + if s % self.stride[dim] == 0: + return max(self.kernel_size[dim] - self.stride[dim], 0) + else: + return max(self.kernel_size[dim] - (s % self.stride[dim]), 0) + + def forward(self, x): + # compute 'same' padding + (batch, channel, t, h, w) = x.size() + pad_t = self.compute_pad(0, t) + pad_h = self.compute_pad(1, h) + pad_w = self.compute_pad(2, w) + + pad_t_f = pad_t // 2 + pad_t_b = pad_t - pad_t_f + pad_h_f = pad_h // 2 + pad_h_b = pad_h - pad_h_f + pad_w_f = pad_w // 2 + pad_w_b = pad_w - pad_w_f + + pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b) + x = F.pad(x, pad) + return super(MaxPool3dSamePadding, self).forward(x) + + +class Unit3D(nn.Module): + def __init__(self, + in_channels, + output_channels, + kernel_shape=(1, 1, 1), + stride=(1, 1, 1), + padding=0, + activation_fn=F.relu, + use_batch_norm=True, + use_bias=False, + name='unit_3d'): + """Initializes Unit3D module.""" + super(Unit3D, self).__init__() + + self._output_channels = output_channels + self._kernel_shape = kernel_shape + self._stride = stride + self._use_batch_norm = use_batch_norm + self._activation_fn = activation_fn + self._use_bias = use_bias + self.name = name + self.padding = padding + + self.conv3d = nn.Conv3d( + in_channels=in_channels, + out_channels=self._output_channels, + kernel_size=self._kernel_shape, + stride=self._stride, + padding=0, # we always want padding to be 0 here. We will + # dynamically pad based on input size in forward function + bias=self._use_bias) + + if self._use_batch_norm: + self.bn = nn.BatchNorm3d(self._output_channels, + eps=0.001, + momentum=0.01) + + def compute_pad(self, dim, s): + if s % self._stride[dim] == 0: + return max(self._kernel_shape[dim] - self._stride[dim], 0) + else: + return max(self._kernel_shape[dim] - (s % self._stride[dim]), 0) + + def forward(self, x): + # compute 'same' padding + (batch, channel, t, h, w) = x.size() + pad_t = self.compute_pad(0, t) + pad_h = self.compute_pad(1, h) + pad_w = self.compute_pad(2, w) + + pad_t_f = pad_t // 2 + pad_t_b = pad_t - pad_t_f + pad_h_f = pad_h // 2 + pad_h_b = pad_h - pad_h_f + pad_w_f = pad_w // 2 + pad_w_b = pad_w - pad_w_f + + pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b) + x = F.pad(x, pad) + + x = self.conv3d(x) + if self._use_batch_norm: + x = self.bn(x) + if self._activation_fn is not None: + x = self._activation_fn(x) + return x + + +class InceptionModule(nn.Module): + def __init__(self, in_channels, out_channels, name): + super(InceptionModule, self).__init__() + + self.b0 = Unit3D(in_channels=in_channels, + output_channels=out_channels[0], + kernel_shape=[1, 1, 1], + padding=0, + name=name + '/Branch_0/Conv3d_0a_1x1') + self.b1a = Unit3D(in_channels=in_channels, + output_channels=out_channels[1], + kernel_shape=[1, 1, 1], + padding=0, + name=name + '/Branch_1/Conv3d_0a_1x1') + self.b1b = Unit3D(in_channels=out_channels[1], + output_channels=out_channels[2], + kernel_shape=[3, 3, 3], + name=name + '/Branch_1/Conv3d_0b_3x3') + self.b2a = Unit3D(in_channels=in_channels, + output_channels=out_channels[3], + kernel_shape=[1, 1, 1], + padding=0, + name=name + '/Branch_2/Conv3d_0a_1x1') + self.b2b = Unit3D(in_channels=out_channels[3], + output_channels=out_channels[4], + kernel_shape=[3, 3, 3], + name=name + '/Branch_2/Conv3d_0b_3x3') + self.b3a = 
MaxPool3dSamePadding(kernel_size=[3, 3, 3], + stride=(1, 1, 1), + padding=0) + self.b3b = Unit3D(in_channels=in_channels, + output_channels=out_channels[5], + kernel_shape=[1, 1, 1], + padding=0, + name=name + '/Branch_3/Conv3d_0b_1x1') + self.name = name + + def forward(self, x): + b0 = self.b0(x) + b1 = self.b1b(self.b1a(x)) + b2 = self.b2b(self.b2a(x)) + b3 = self.b3b(self.b3a(x)) + return torch.cat([b0, b1, b2, b3], dim=1) + + +class InceptionI3d(nn.Module): + """Inception-v1 I3D architecture. + The model is introduced in: + Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset + Joao Carreira, Andrew Zisserman + https://arxiv.org/pdf/1705.07750v1.pdf. + See also the Inception architecture, introduced in: + Going deeper with convolutions + Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed, + Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich. + http://arxiv.org/pdf/1409.4842v1.pdf. + """ + + # Endpoints of the model in order. During construction, all the endpoints up + # to a designated `final_endpoint` are returned in a dictionary as the + # second return value. + VALID_ENDPOINTS = ( + 'Conv3d_1a_7x7', + 'MaxPool3d_2a_3x3', + 'Conv3d_2b_1x1', + 'Conv3d_2c_3x3', + 'MaxPool3d_3a_3x3', + 'Mixed_3b', + 'Mixed_3c', + 'MaxPool3d_4a_3x3', + 'Mixed_4b', + 'Mixed_4c', + 'Mixed_4d', + 'Mixed_4e', + 'Mixed_4f', + 'MaxPool3d_5a_2x2', + 'Mixed_5b', + 'Mixed_5c', + 'Logits', + 'Predictions', + ) + + def __init__(self, + num_classes=400, + spatial_squeeze=True, + final_endpoint='Logits', + name='inception_i3d', + in_channels=3, + dropout_keep_prob=0.5): + """Initializes I3D model instance. + Args: + num_classes: The number of outputs in the logit layer (default 400, which + matches the Kinetics dataset). + spatial_squeeze: Whether to squeeze the spatial dimensions for the logits + before returning (default True). + final_endpoint: The model contains many possible endpoints. + `final_endpoint` specifies the last endpoint for the model to be built + up to. In addition to the output at `final_endpoint`, all the outputs + at endpoints up to `final_endpoint` will also be returned, in a + dictionary. `final_endpoint` must be one of + InceptionI3d.VALID_ENDPOINTS (default 'Logits'). + name: A string (optional). The name of this module. + Raises: + ValueError: if `final_endpoint` is not recognized. 
+ """ + + if final_endpoint not in self.VALID_ENDPOINTS: + raise ValueError('Unknown final endpoint %s' % final_endpoint) + + super(InceptionI3d, self).__init__() + self._num_classes = num_classes + self._spatial_squeeze = spatial_squeeze + self._final_endpoint = final_endpoint + self.logits = None + + if self._final_endpoint not in self.VALID_ENDPOINTS: + raise ValueError('Unknown final endpoint %s' % + self._final_endpoint) + + self.end_points = {} + end_point = 'Conv3d_1a_7x7' + self.end_points[end_point] = Unit3D(in_channels=in_channels, + output_channels=64, + kernel_shape=[7, 7, 7], + stride=(2, 2, 2), + padding=(3, 3, 3), + name=name + end_point) + if self._final_endpoint == end_point: + return + + end_point = 'MaxPool3d_2a_3x3' + self.end_points[end_point] = MaxPool3dSamePadding( + kernel_size=[1, 3, 3], stride=(1, 2, 2), padding=0) + if self._final_endpoint == end_point: + return + + end_point = 'Conv3d_2b_1x1' + self.end_points[end_point] = Unit3D(in_channels=64, + output_channels=64, + kernel_shape=[1, 1, 1], + padding=0, + name=name + end_point) + if self._final_endpoint == end_point: + return + + end_point = 'Conv3d_2c_3x3' + self.end_points[end_point] = Unit3D(in_channels=64, + output_channels=192, + kernel_shape=[3, 3, 3], + padding=1, + name=name + end_point) + if self._final_endpoint == end_point: + return + + end_point = 'MaxPool3d_3a_3x3' + self.end_points[end_point] = MaxPool3dSamePadding( + kernel_size=[1, 3, 3], stride=(1, 2, 2), padding=0) + if self._final_endpoint == end_point: + return + + end_point = 'Mixed_3b' + self.end_points[end_point] = InceptionModule(192, + [64, 96, 128, 16, 32, 32], + name + end_point) + if self._final_endpoint == end_point: + return + + end_point = 'Mixed_3c' + self.end_points[end_point] = InceptionModule( + 256, [128, 128, 192, 32, 96, 64], name + end_point) + if self._final_endpoint == end_point: + return + + end_point = 'MaxPool3d_4a_3x3' + self.end_points[end_point] = MaxPool3dSamePadding( + kernel_size=[3, 3, 3], stride=(2, 2, 2), padding=0) + if self._final_endpoint == end_point: + return + + end_point = 'Mixed_4b' + self.end_points[end_point] = InceptionModule( + 128 + 192 + 96 + 64, [192, 96, 208, 16, 48, 64], name + end_point) + if self._final_endpoint == end_point: + return + + end_point = 'Mixed_4c' + self.end_points[end_point] = InceptionModule( + 192 + 208 + 48 + 64, [160, 112, 224, 24, 64, 64], name + end_point) + if self._final_endpoint == end_point: + return + + end_point = 'Mixed_4d' + self.end_points[end_point] = InceptionModule( + 160 + 224 + 64 + 64, [128, 128, 256, 24, 64, 64], name + end_point) + if self._final_endpoint == end_point: + return + + end_point = 'Mixed_4e' + self.end_points[end_point] = InceptionModule( + 128 + 256 + 64 + 64, [112, 144, 288, 32, 64, 64], name + end_point) + if self._final_endpoint == end_point: + return + + end_point = 'Mixed_4f' + self.end_points[end_point] = InceptionModule( + 112 + 288 + 64 + 64, [256, 160, 320, 32, 128, 128], + name + end_point) + if self._final_endpoint == end_point: + return + + end_point = 'MaxPool3d_5a_2x2' + self.end_points[end_point] = MaxPool3dSamePadding( + kernel_size=[2, 2, 2], stride=(2, 2, 2), padding=0) + if self._final_endpoint == end_point: + return + + end_point = 'Mixed_5b' + self.end_points[end_point] = InceptionModule( + 256 + 320 + 128 + 128, [256, 160, 320, 32, 128, 128], + name + end_point) + if self._final_endpoint == end_point: + return + + end_point = 'Mixed_5c' + self.end_points[end_point] = InceptionModule( + 256 + 320 + 128 + 128, [384, 
192, 384, 48, 128, 128], + name + end_point) + if self._final_endpoint == end_point: + return + + end_point = 'Logits' + self.avg_pool = nn.AvgPool3d(kernel_size=[2, 7, 7], stride=(1, 1, 1)) + self.dropout = nn.Dropout(dropout_keep_prob) + self.logits = Unit3D(in_channels=384 + 384 + 128 + 128, + output_channels=self._num_classes, + kernel_shape=[1, 1, 1], + padding=0, + activation_fn=None, + use_batch_norm=False, + use_bias=True, + name='logits') + + self.build() + + def replace_logits(self, num_classes): + self._num_classes = num_classes + self.logits = Unit3D(in_channels=384 + 384 + 128 + 128, + output_channels=self._num_classes, + kernel_shape=[1, 1, 1], + padding=0, + activation_fn=None, + use_batch_norm=False, + use_bias=True, + name='logits') + + def build(self): + for k in self.end_points.keys(): + self.add_module(k, self.end_points[k]) + + def forward(self, x): + for end_point in self.VALID_ENDPOINTS: + if end_point in self.end_points: + x = self._modules[end_point]( + x) # use _modules to work with dataparallel + + x = self.logits(self.dropout(self.avg_pool(x))) + if self._spatial_squeeze: + logits = x.squeeze(3).squeeze(3) + # logits is batch X time X classes, which is what we want to work with + return logits + + def extract_features(self, x, target_endpoint='Logits'): + for end_point in self.VALID_ENDPOINTS: + if end_point in self.end_points: + x = self._modules[end_point](x) + if end_point == target_endpoint: + break + if target_endpoint == 'Logits': + return x.mean(4).mean(3).mean(2) + else: + return x diff --git a/propainter/core/prefetch_dataloader.py b/propainter/core/prefetch_dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..5088425050d4cc98114a9b93eb50ea60273f35a0 --- /dev/null +++ b/propainter/core/prefetch_dataloader.py @@ -0,0 +1,125 @@ +import queue as Queue +import threading +import torch +from torch.utils.data import DataLoader + + +class PrefetchGenerator(threading.Thread): + """A general prefetch generator. + + Ref: + https://stackoverflow.com/questions/7323664/python-generator-pre-fetch + + Args: + generator: Python generator. + num_prefetch_queue (int): Number of prefetch queue. + """ + + def __init__(self, generator, num_prefetch_queue): + threading.Thread.__init__(self) + self.queue = Queue.Queue(num_prefetch_queue) + self.generator = generator + self.daemon = True + self.start() + + def run(self): + for item in self.generator: + self.queue.put(item) + self.queue.put(None) + + def __next__(self): + next_item = self.queue.get() + if next_item is None: + raise StopIteration + return next_item + + def __iter__(self): + return self + + +class PrefetchDataLoader(DataLoader): + """Prefetch version of dataloader. + + Ref: + https://github.com/IgorSusmelj/pytorch-styleguide/issues/5# + + TODO: + Need to test on single gpu and ddp (multi-gpu). There is a known issue in + ddp. + + Args: + num_prefetch_queue (int): Number of prefetch queue. + kwargs (dict): Other arguments for dataloader. + """ + + def __init__(self, num_prefetch_queue, **kwargs): + self.num_prefetch_queue = num_prefetch_queue + super(PrefetchDataLoader, self).__init__(**kwargs) + + def __iter__(self): + return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue) + + +class CPUPrefetcher(): + """CPU prefetcher. + + Args: + loader: Dataloader. 
+ """ + + def __init__(self, loader): + self.ori_loader = loader + self.loader = iter(loader) + + def next(self): + try: + return next(self.loader) + except StopIteration: + return None + + def reset(self): + self.loader = iter(self.ori_loader) + + +class CUDAPrefetcher(): + """CUDA prefetcher. + + Ref: + https://github.com/NVIDIA/apex/issues/304# + + It may consums more GPU memory. + + Args: + loader: Dataloader. + opt (dict): Options. + """ + + def __init__(self, loader, opt): + self.ori_loader = loader + self.loader = iter(loader) + self.opt = opt + self.stream = torch.cuda.Stream() + self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') + self.preload() + + def preload(self): + try: + self.batch = next(self.loader) # self.batch is a dict + except StopIteration: + self.batch = None + return None + # put tensors to gpu + with torch.cuda.stream(self.stream): + for k, v in self.batch.items(): + if torch.is_tensor(v): + self.batch[k] = self.batch[k].to(device=self.device, non_blocking=True) + + def next(self): + torch.cuda.current_stream().wait_stream(self.stream) + batch = self.batch + self.preload() + return batch + + def reset(self): + self.loader = iter(self.ori_loader) + self.preload() diff --git a/propainter/core/trainer.py b/propainter/core/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..5e6b6a669e92131697fcb12fd548509ce1b81080 --- /dev/null +++ b/propainter/core/trainer.py @@ -0,0 +1,509 @@ +import os +import glob +import logging +import importlib +from tqdm import tqdm + +import torch +import torch.nn as nn +import torch.nn.functional as F +from core.prefetch_dataloader import PrefetchDataLoader, CPUPrefetcher +from torch.utils.data.distributed import DistributedSampler +from torch.nn.parallel import DistributedDataParallel as DDP +import torchvision +from torch.utils.tensorboard import SummaryWriter + +from core.lr_scheduler import MultiStepRestartLR, CosineAnnealingRestartLR +from core.loss import AdversarialLoss, PerceptualLoss, LPIPSLoss +from core.dataset import TrainDataset + +from model.modules.flow_comp_raft import RAFT_bi, FlowLoss, EdgeLoss +from model.recurrent_flow_completion import RecurrentFlowCompleteNet + +from RAFT.utils.flow_viz_pt import flow_to_image + + +class Trainer: + def __init__(self, config): + self.config = config + self.epoch = 0 + self.iteration = 0 + self.num_local_frames = config['train_data_loader']['num_local_frames'] + self.num_ref_frames = config['train_data_loader']['num_ref_frames'] + + # setup data set and data loader + self.train_dataset = TrainDataset(config['train_data_loader']) + + self.train_sampler = None + self.train_args = config['trainer'] + if config['distributed']: + self.train_sampler = DistributedSampler( + self.train_dataset, + num_replicas=config['world_size'], + rank=config['global_rank']) + + dataloader_args = dict( + dataset=self.train_dataset, + batch_size=self.train_args['batch_size'] // config['world_size'], + shuffle=(self.train_sampler is None), + num_workers=self.train_args['num_workers'], + sampler=self.train_sampler, + drop_last=True) + + self.train_loader = PrefetchDataLoader(self.train_args['num_prefetch_queue'], **dataloader_args) + self.prefetcher = CPUPrefetcher(self.train_loader) + + # set loss functions + self.adversarial_loss = AdversarialLoss(type=self.config['losses']['GAN_LOSS']) + self.adversarial_loss = self.adversarial_loss.to(self.config['device']) + self.l1_loss = nn.L1Loss() + # self.perc_loss = PerceptualLoss( + # layer_weights={'conv3_4': 0.25, 
'conv4_4': 0.25, 'conv5_4': 0.5}, + # use_input_norm=True, + # range_norm=True, + # criterion='l1' + # ).to(self.config['device']) + + if self.config['losses']['perceptual_weight'] > 0: + self.perc_loss = LPIPSLoss(use_input_norm=True, range_norm=True).to(self.config['device']) + + # self.flow_comp_loss = FlowCompletionLoss().to(self.config['device']) + # self.flow_comp_loss = FlowCompletionLoss(self.config['device']) + + # set raft + self.fix_raft = RAFT_bi(device = self.config['device']) + self.fix_flow_complete = RecurrentFlowCompleteNet('weights/recurrent_flow_completion.pth') + for p in self.fix_flow_complete.parameters(): + p.requires_grad = False + self.fix_flow_complete.to(self.config['device']) + self.fix_flow_complete.eval() + + # self.flow_loss = FlowLoss() + + # setup models including generator and discriminator + net = importlib.import_module('model.' + config['model']['net']) + self.netG = net.InpaintGenerator() + # print(self.netG) + self.netG = self.netG.to(self.config['device']) + if not self.config['model'].get('no_dis', False): + if self.config['model'].get('dis_2d', False): + self.netD = net.Discriminator_2D( + in_channels=3, + use_sigmoid=config['losses']['GAN_LOSS'] != 'hinge') + else: + self.netD = net.Discriminator( + in_channels=3, + use_sigmoid=config['losses']['GAN_LOSS'] != 'hinge') + self.netD = self.netD.to(self.config['device']) + + self.interp_mode = self.config['model']['interp_mode'] + # setup optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + self.load() + + if config['distributed']: + self.netG = DDP(self.netG, + device_ids=[self.config['local_rank']], + output_device=self.config['local_rank'], + broadcast_buffers=True, + find_unused_parameters=True) + if not self.config['model']['no_dis']: + self.netD = DDP(self.netD, + device_ids=[self.config['local_rank']], + output_device=self.config['local_rank'], + broadcast_buffers=True, + find_unused_parameters=False) + + # set summary writer + self.dis_writer = None + self.gen_writer = None + self.summary = {} + if self.config['global_rank'] == 0 or (not config['distributed']): + if not self.config['model']['no_dis']: + self.dis_writer = SummaryWriter( + os.path.join(config['save_dir'], 'dis')) + self.gen_writer = SummaryWriter( + os.path.join(config['save_dir'], 'gen')) + + def setup_optimizers(self): + """Set up optimizers.""" + backbone_params = [] + for name, param in self.netG.named_parameters(): + if param.requires_grad: + backbone_params.append(param) + else: + print(f'Params {name} will not be optimized.') + + optim_params = [ + { + 'params': backbone_params, + 'lr': self.config['trainer']['lr'] + }, + ] + + self.optimG = torch.optim.Adam(optim_params, + betas=(self.config['trainer']['beta1'], + self.config['trainer']['beta2'])) + + if not self.config['model']['no_dis']: + self.optimD = torch.optim.Adam( + self.netD.parameters(), + lr=self.config['trainer']['lr'], + betas=(self.config['trainer']['beta1'], + self.config['trainer']['beta2'])) + + def setup_schedulers(self): + """Set up schedulers.""" + scheduler_opt = self.config['trainer']['scheduler'] + scheduler_type = scheduler_opt.pop('type') + + if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']: + self.scheG = MultiStepRestartLR( + self.optimG, + milestones=scheduler_opt['milestones'], + gamma=scheduler_opt['gamma']) + if not self.config['model']['no_dis']: + self.scheD = MultiStepRestartLR( + self.optimD, + milestones=scheduler_opt['milestones'], + gamma=scheduler_opt['gamma']) + elif scheduler_type == 
'CosineAnnealingRestartLR': + self.scheG = CosineAnnealingRestartLR( + self.optimG, + periods=scheduler_opt['periods'], + restart_weights=scheduler_opt['restart_weights'], + eta_min=scheduler_opt['eta_min']) + if not self.config['model']['no_dis']: + self.scheD = CosineAnnealingRestartLR( + self.optimD, + periods=scheduler_opt['periods'], + restart_weights=scheduler_opt['restart_weights'], + eta_min=scheduler_opt['eta_min']) + else: + raise NotImplementedError( + f'Scheduler {scheduler_type} is not implemented yet.') + + def update_learning_rate(self): + """Update learning rate.""" + self.scheG.step() + if not self.config['model']['no_dis']: + self.scheD.step() + + def get_lr(self): + """Get current learning rate.""" + return self.optimG.param_groups[0]['lr'] + + def add_summary(self, writer, name, val): + """Add tensorboard summary.""" + if name not in self.summary: + self.summary[name] = 0 + self.summary[name] += val + n = self.train_args['log_freq'] + if writer is not None and self.iteration % n == 0: + writer.add_scalar(name, self.summary[name] / n, self.iteration) + self.summary[name] = 0 + + def load(self): + """Load netG (and netD).""" + # get the latest checkpoint + model_path = self.config['save_dir'] + # TODO: add resume name + if os.path.isfile(os.path.join(model_path, 'latest.ckpt')): + latest_epoch = open(os.path.join(model_path, 'latest.ckpt'), + 'r').read().splitlines()[-1] + else: + ckpts = [ + os.path.basename(i).split('.pth')[0] + for i in glob.glob(os.path.join(model_path, '*.pth')) + ] + ckpts.sort() + latest_epoch = ckpts[-1][4:] if len(ckpts) > 0 else None + + if latest_epoch is not None: + gen_path = os.path.join(model_path, + f'gen_{int(latest_epoch):06d}.pth') + dis_path = os.path.join(model_path, + f'dis_{int(latest_epoch):06d}.pth') + opt_path = os.path.join(model_path, + f'opt_{int(latest_epoch):06d}.pth') + + if self.config['global_rank'] == 0: + print(f'Loading model from {gen_path}...') + dataG = torch.load(gen_path, map_location=self.config['device']) + self.netG.load_state_dict(dataG) + if not self.config['model']['no_dis'] and self.config['model']['load_d']: + dataD = torch.load(dis_path, map_location=self.config['device']) + self.netD.load_state_dict(dataD) + + data_opt = torch.load(opt_path, map_location=self.config['device']) + self.optimG.load_state_dict(data_opt['optimG']) + # self.scheG.load_state_dict(data_opt['scheG']) + if not self.config['model']['no_dis'] and self.config['model']['load_d']: + self.optimD.load_state_dict(data_opt['optimD']) + # self.scheD.load_state_dict(data_opt['scheD']) + self.epoch = data_opt['epoch'] + self.iteration = data_opt['iteration'] + else: + gen_path = self.config['trainer'].get('gen_path', None) + dis_path = self.config['trainer'].get('dis_path', None) + opt_path = self.config['trainer'].get('opt_path', None) + if gen_path is not None: + if self.config['global_rank'] == 0: + print(f'Loading Gen-Net from {gen_path}...') + dataG = torch.load(gen_path, map_location=self.config['device']) + self.netG.load_state_dict(dataG) + + if dis_path is not None and not self.config['model']['no_dis'] and self.config['model']['load_d']: + if self.config['global_rank'] == 0: + print(f'Loading Dis-Net from {dis_path}...') + dataD = torch.load(dis_path, map_location=self.config['device']) + self.netD.load_state_dict(dataD) + if opt_path is not None: + data_opt = torch.load(opt_path, map_location=self.config['device']) + self.optimG.load_state_dict(data_opt['optimG']) + self.scheG.load_state_dict(data_opt['scheG']) + if not 
self.config['model']['no_dis'] and self.config['model']['load_d']: + self.optimD.load_state_dict(data_opt['optimD']) + self.scheD.load_state_dict(data_opt['scheD']) + else: + if self.config['global_rank'] == 0: + print('Warnning: There is no trained model found.' + 'An initialized model will be used.') + + def save(self, it): + """Save parameters every eval_epoch""" + if self.config['global_rank'] == 0: + # configure path + gen_path = os.path.join(self.config['save_dir'], + f'gen_{it:06d}.pth') + dis_path = os.path.join(self.config['save_dir'], + f'dis_{it:06d}.pth') + opt_path = os.path.join(self.config['save_dir'], + f'opt_{it:06d}.pth') + print(f'\nsaving model to {gen_path} ...') + + # remove .module for saving + if isinstance(self.netG, torch.nn.DataParallel) or isinstance(self.netG, DDP): + netG = self.netG.module + if not self.config['model']['no_dis']: + netD = self.netD.module + else: + netG = self.netG + if not self.config['model']['no_dis']: + netD = self.netD + + # save checkpoints + torch.save(netG.state_dict(), gen_path) + if not self.config['model']['no_dis']: + torch.save(netD.state_dict(), dis_path) + torch.save( + { + 'epoch': self.epoch, + 'iteration': self.iteration, + 'optimG': self.optimG.state_dict(), + 'optimD': self.optimD.state_dict(), + 'scheG': self.scheG.state_dict(), + 'scheD': self.scheD.state_dict() + }, opt_path) + else: + torch.save( + { + 'epoch': self.epoch, + 'iteration': self.iteration, + 'optimG': self.optimG.state_dict(), + 'scheG': self.scheG.state_dict() + }, opt_path) + + latest_path = os.path.join(self.config['save_dir'], 'latest.ckpt') + os.system(f"echo {it:06d} > {latest_path}") + + def train(self): + """training entry""" + pbar = range(int(self.train_args['iterations'])) + if self.config['global_rank'] == 0: + pbar = tqdm(pbar, + initial=self.iteration, + dynamic_ncols=True, + smoothing=0.01) + + os.makedirs('logs', exist_ok=True) + + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(filename)s[line:%(lineno)d]" + "%(levelname)s %(message)s", + datefmt="%a, %d %b %Y %H:%M:%S", + filename=f"logs/{self.config['save_dir'].split('/')[-1]}.log", + filemode='w') + + while True: + self.epoch += 1 + self.prefetcher.reset() + if self.config['distributed']: + self.train_sampler.set_epoch(self.epoch) + self._train_epoch(pbar) + if self.iteration > self.train_args['iterations']: + break + print('\nEnd training....') + + def _train_epoch(self, pbar): + """Process input and calculate loss every training epoch""" + device = self.config['device'] + train_data = self.prefetcher.next() + while train_data is not None: + self.iteration += 1 + frames, masks, flows_f, flows_b, _ = train_data + frames, masks = frames.to(device), masks.to(device).float() + l_t = self.num_local_frames + b, t, c, h, w = frames.size() + gt_local_frames = frames[:, :l_t, ...] + local_masks = masks[:, :l_t, ...].contiguous() + + masked_frames = frames * (1 - masks) + masked_local_frames = masked_frames[:, :l_t, ...] 
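+ # Per-iteration pipeline: ground-truth flows come from the dataloader when
+ # precomputed, otherwise from the frozen RAFT model; the frozen flow-completion
+ # network fills in the masked flow regions; the completed flows drive image
+ # propagation to pre-fill masked pixels in the local frames; the generator
+ # (feature propagation + transformer) then predicts all frames, supervised by
+ # hole/valid L1 losses, an optional LPIPS perceptual loss and, unless 'no_dis'
+ # is set, an adversarial loss.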
+ # get gt optical flow + if flows_f[0] == 'None' or flows_b[0] == 'None': + gt_flows_bi = self.fix_raft(gt_local_frames) + else: + gt_flows_bi = (flows_f.to(device), flows_b.to(device)) + + # ---- complete flow ---- + pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow(gt_flows_bi, local_masks) + pred_flows_bi = self.fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, local_masks) + # pred_flows_bi = gt_flows_bi + + # ---- image propagation ---- + prop_imgs, updated_local_masks = self.netG.module.img_propagation(masked_local_frames, pred_flows_bi, local_masks, interpolation=self.interp_mode) + updated_masks = masks.clone() + updated_masks[:, :l_t, ...] = updated_local_masks.view(b, l_t, 1, h, w) + updated_frames = masked_frames.clone() + prop_local_frames = gt_local_frames * (1-local_masks) + prop_imgs.view(b, l_t, 3, h, w) * local_masks # merge + updated_frames[:, :l_t, ...] = prop_local_frames + + # ---- feature propagation + Transformer ---- + pred_imgs = self.netG(updated_frames, pred_flows_bi, masks, updated_masks, l_t) + pred_imgs = pred_imgs.view(b, -1, c, h, w) + + # get the local frames + pred_local_frames = pred_imgs[:, :l_t, ...] + comp_local_frames = gt_local_frames * (1. - local_masks) + pred_local_frames * local_masks + comp_imgs = frames * (1. - masks) + pred_imgs * masks + + gen_loss = 0 + dis_loss = 0 + # optimize net_g + if not self.config['model']['no_dis']: + for p in self.netD.parameters(): + p.requires_grad = False + + self.optimG.zero_grad() + + # generator l1 loss + hole_loss = self.l1_loss(pred_imgs * masks, frames * masks) + hole_loss = hole_loss / torch.mean(masks) * self.config['losses']['hole_weight'] + gen_loss += hole_loss + self.add_summary(self.gen_writer, 'loss/hole_loss', hole_loss.item()) + + valid_loss = self.l1_loss(pred_imgs * (1 - masks), frames * (1 - masks)) + valid_loss = valid_loss / torch.mean(1-masks) * self.config['losses']['valid_weight'] + gen_loss += valid_loss + self.add_summary(self.gen_writer, 'loss/valid_loss', valid_loss.item()) + + # perceptual loss + if self.config['losses']['perceptual_weight'] > 0: + perc_loss = self.perc_loss(pred_imgs.view(-1,3,h,w), frames.view(-1,3,h,w))[0] * self.config['losses']['perceptual_weight'] + gen_loss += perc_loss + self.add_summary(self.gen_writer, 'loss/perc_loss', perc_loss.item()) + + # gan loss + if not self.config['model']['no_dis']: + # generator adversarial loss + gen_clip = self.netD(comp_imgs) + gan_loss = self.adversarial_loss(gen_clip, True, False) + gan_loss = gan_loss * self.config['losses']['adversarial_weight'] + gen_loss += gan_loss + self.add_summary(self.gen_writer, 'loss/gan_loss', gan_loss.item()) + gen_loss.backward() + self.optimG.step() + + if not self.config['model']['no_dis']: + # optimize net_d + for p in self.netD.parameters(): + p.requires_grad = True + self.optimD.zero_grad() + + # discriminator adversarial loss + real_clip = self.netD(frames) + fake_clip = self.netD(comp_imgs.detach()) + dis_real_loss = self.adversarial_loss(real_clip, True, True) + dis_fake_loss = self.adversarial_loss(fake_clip, False, True) + dis_loss += (dis_real_loss + dis_fake_loss) / 2 + self.add_summary(self.dis_writer, 'loss/dis_vid_real', dis_real_loss.item()) + self.add_summary(self.dis_writer, 'loss/dis_vid_fake', dis_fake_loss.item()) + dis_loss.backward() + self.optimD.step() + + self.update_learning_rate() + + # write image to tensorboard + if self.iteration % 200 == 0: + # img to cpu + t = 0 + gt_local_frames_cpu = ((gt_local_frames.view(b,-1,3,h,w) + 1)/2.0).cpu() + 
masked_local_frames = ((masked_local_frames.view(b,-1,3,h,w) + 1)/2.0).cpu() + prop_local_frames_cpu = ((prop_local_frames.view(b,-1,3,h,w) + 1)/2.0).cpu() + pred_local_frames_cpu = ((pred_local_frames.view(b,-1,3,h,w) + 1)/2.0).cpu() + img_results = torch.cat([masked_local_frames[0][t], gt_local_frames_cpu[0][t], + prop_local_frames_cpu[0][t], pred_local_frames_cpu[0][t]], 1) + img_results = torchvision.utils.make_grid(img_results, nrow=1, normalize=True) + if self.gen_writer is not None: + self.gen_writer.add_image(f'img/img:inp-gt-res-{t}', img_results, self.iteration) + + t = 5 + if masked_local_frames.shape[1] > 5: + img_results = torch.cat([masked_local_frames[0][t], gt_local_frames_cpu[0][t], + prop_local_frames_cpu[0][t], pred_local_frames_cpu[0][t]], 1) + img_results = torchvision.utils.make_grid(img_results, nrow=1, normalize=True) + if self.gen_writer is not None: + self.gen_writer.add_image(f'img/img:inp-gt-res-{t}', img_results, self.iteration) + + # flow to cpu + gt_flows_forward_cpu = flow_to_image(gt_flows_bi[0][0]).cpu() + masked_flows_forward_cpu = (gt_flows_forward_cpu[0] * (1-local_masks[0][0].cpu())).to(gt_flows_forward_cpu) + pred_flows_forward_cpu = flow_to_image(pred_flows_bi[0][0]).cpu() + + flow_results = torch.cat([gt_flows_forward_cpu[0], masked_flows_forward_cpu, pred_flows_forward_cpu[0]], 1) + if self.gen_writer is not None: + self.gen_writer.add_image('img/flow:gt-pred', flow_results, self.iteration) + + # console logs + if self.config['global_rank'] == 0: + pbar.update(1) + if not self.config['model']['no_dis']: + pbar.set_description((f"d: {dis_loss.item():.3f}; " + f"hole: {hole_loss.item():.3f}; " + f"valid: {valid_loss.item():.3f}")) + else: + pbar.set_description((f"hole: {hole_loss.item():.3f}; " + f"valid: {valid_loss.item():.3f}")) + + if self.iteration % self.train_args['log_freq'] == 0: + if not self.config['model']['no_dis']: + logging.info(f"[Iter {self.iteration}] " + f"d: {dis_loss.item():.4f}; " + f"hole: {hole_loss.item():.4f}; " + f"valid: {valid_loss.item():.4f}") + else: + logging.info(f"[Iter {self.iteration}] " + f"hole: {hole_loss.item():.4f}; " + f"valid: {valid_loss.item():.4f}") + + # saving models + if self.iteration % self.train_args['save_freq'] == 0: + self.save(int(self.iteration)) + + if self.iteration > self.train_args['iterations']: + break + + train_data = self.prefetcher.next() \ No newline at end of file diff --git a/propainter/core/trainer_flow_w_edge.py b/propainter/core/trainer_flow_w_edge.py new file mode 100644 index 0000000000000000000000000000000000000000..d4eba04c8a5fa56bce3e335e6036bc0e0a1e848a --- /dev/null +++ b/propainter/core/trainer_flow_w_edge.py @@ -0,0 +1,380 @@ +import os +import glob +import logging +import importlib +from tqdm import tqdm + +import torch +import torch.nn as nn +import torch.nn.functional as F +from core.prefetch_dataloader import PrefetchDataLoader, CPUPrefetcher +from torch.utils.data.distributed import DistributedSampler +from torch.nn.parallel import DistributedDataParallel as DDP + +from torch.utils.tensorboard import SummaryWriter + +from core.lr_scheduler import MultiStepRestartLR, CosineAnnealingRestartLR +from core.dataset import TrainDataset + +from model.modules.flow_comp_raft import RAFT_bi, FlowLoss, EdgeLoss + +# from skimage.feature import canny +from model.canny.canny_filter import Canny +from RAFT.utils.flow_viz_pt import flow_to_image + + +class Trainer: + def __init__(self, config): + self.config = config + self.epoch = 0 + self.iteration = 0 + self.num_local_frames = 
config['train_data_loader']['num_local_frames'] + self.num_ref_frames = config['train_data_loader']['num_ref_frames'] + + # setup data set and data loader + self.train_dataset = TrainDataset(config['train_data_loader']) + + self.train_sampler = None + self.train_args = config['trainer'] + if config['distributed']: + self.train_sampler = DistributedSampler( + self.train_dataset, + num_replicas=config['world_size'], + rank=config['global_rank']) + + dataloader_args = dict( + dataset=self.train_dataset, + batch_size=self.train_args['batch_size'] // config['world_size'], + shuffle=(self.train_sampler is None), + num_workers=self.train_args['num_workers'], + sampler=self.train_sampler, + drop_last=True) + + self.train_loader = PrefetchDataLoader(self.train_args['num_prefetch_queue'], **dataloader_args) + self.prefetcher = CPUPrefetcher(self.train_loader) + + # set raft + self.fix_raft = RAFT_bi(device = self.config['device']) + self.flow_loss = FlowLoss() + self.edge_loss = EdgeLoss() + self.canny = Canny(sigma=(2,2), low_threshold=0.1, high_threshold=0.2) + + # setup models including generator and discriminator + net = importlib.import_module('model.' + config['model']['net']) + self.netG = net.RecurrentFlowCompleteNet() + # print(self.netG) + self.netG = self.netG.to(self.config['device']) + + # setup optimizers and schedulers + self.setup_optimizers() + self.setup_schedulers() + self.load() + + if config['distributed']: + self.netG = DDP(self.netG, + device_ids=[self.config['local_rank']], + output_device=self.config['local_rank'], + broadcast_buffers=True, + find_unused_parameters=True) + + # set summary writer + self.dis_writer = None + self.gen_writer = None + self.summary = {} + if self.config['global_rank'] == 0 or (not config['distributed']): + self.gen_writer = SummaryWriter( + os.path.join(config['save_dir'], 'gen')) + + def setup_optimizers(self): + """Set up optimizers.""" + backbone_params = [] + for name, param in self.netG.named_parameters(): + if param.requires_grad: + backbone_params.append(param) + else: + print(f'Params {name} will not be optimized.') + + optim_params = [ + { + 'params': backbone_params, + 'lr': self.config['trainer']['lr'] + }, + ] + + self.optimG = torch.optim.Adam(optim_params, + betas=(self.config['trainer']['beta1'], + self.config['trainer']['beta2'])) + + + def setup_schedulers(self): + """Set up schedulers.""" + scheduler_opt = self.config['trainer']['scheduler'] + scheduler_type = scheduler_opt.pop('type') + + if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']: + self.scheG = MultiStepRestartLR( + self.optimG, + milestones=scheduler_opt['milestones'], + gamma=scheduler_opt['gamma']) + elif scheduler_type == 'CosineAnnealingRestartLR': + self.scheG = CosineAnnealingRestartLR( + self.optimG, + periods=scheduler_opt['periods'], + restart_weights=scheduler_opt['restart_weights']) + else: + raise NotImplementedError( + f'Scheduler {scheduler_type} is not implemented yet.') + + def update_learning_rate(self): + """Update learning rate.""" + self.scheG.step() + + def get_lr(self): + """Get current learning rate.""" + return self.optimG.param_groups[0]['lr'] + + def add_summary(self, writer, name, val): + """Add tensorboard summary.""" + if name not in self.summary: + self.summary[name] = 0 + self.summary[name] += val + n = self.train_args['log_freq'] + if writer is not None and self.iteration % n == 0: + writer.add_scalar(name, self.summary[name] / n, self.iteration) + self.summary[name] = 0 + + def load(self): + """Load netG.""" + # get the latest 
checkpoint + model_path = self.config['save_dir'] + if os.path.isfile(os.path.join(model_path, 'latest.ckpt')): + latest_epoch = open(os.path.join(model_path, 'latest.ckpt'), + 'r').read().splitlines()[-1] + else: + ckpts = [ + os.path.basename(i).split('.pth')[0] + for i in glob.glob(os.path.join(model_path, '*.pth')) + ] + ckpts.sort() + latest_epoch = ckpts[-1][4:] if len(ckpts) > 0 else None + + if latest_epoch is not None: + gen_path = os.path.join(model_path, f'gen_{int(latest_epoch):06d}.pth') + opt_path = os.path.join(model_path,f'opt_{int(latest_epoch):06d}.pth') + + if self.config['global_rank'] == 0: + print(f'Loading model from {gen_path}...') + dataG = torch.load(gen_path, map_location=self.config['device']) + self.netG.load_state_dict(dataG) + + + data_opt = torch.load(opt_path, map_location=self.config['device']) + self.optimG.load_state_dict(data_opt['optimG']) + self.scheG.load_state_dict(data_opt['scheG']) + + self.epoch = data_opt['epoch'] + self.iteration = data_opt['iteration'] + + else: + if self.config['global_rank'] == 0: + print('Warnning: There is no trained model found.' + 'An initialized model will be used.') + + def save(self, it): + """Save parameters every eval_epoch""" + if self.config['global_rank'] == 0: + # configure path + gen_path = os.path.join(self.config['save_dir'], + f'gen_{it:06d}.pth') + opt_path = os.path.join(self.config['save_dir'], + f'opt_{it:06d}.pth') + print(f'\nsaving model to {gen_path} ...') + + # remove .module for saving + if isinstance(self.netG, torch.nn.DataParallel) or isinstance(self.netG, DDP): + netG = self.netG.module + else: + netG = self.netG + + # save checkpoints + torch.save(netG.state_dict(), gen_path) + torch.save( + { + 'epoch': self.epoch, + 'iteration': self.iteration, + 'optimG': self.optimG.state_dict(), + 'scheG': self.scheG.state_dict() + }, opt_path) + + latest_path = os.path.join(self.config['save_dir'], 'latest.ckpt') + os.system(f"echo {it:06d} > {latest_path}") + + def train(self): + """training entry""" + pbar = range(int(self.train_args['iterations'])) + if self.config['global_rank'] == 0: + pbar = tqdm(pbar, + initial=self.iteration, + dynamic_ncols=True, + smoothing=0.01) + + os.makedirs('logs', exist_ok=True) + + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s %(filename)s[line:%(lineno)d]" + "%(levelname)s %(message)s", + datefmt="%a, %d %b %Y %H:%M:%S", + filename=f"logs/{self.config['save_dir'].split('/')[-1]}.log", + filemode='w') + + while True: + self.epoch += 1 + self.prefetcher.reset() + if self.config['distributed']: + self.train_sampler.set_epoch(self.epoch) + self._train_epoch(pbar) + if self.iteration > self.train_args['iterations']: + break + print('\nEnd training....') + + # def get_edges(self, flows): # fgvc + # # (b, t, 2, H, W) + # b, t, _, h, w = flows.shape + # flows = flows.view(-1, 2, h, w) + # flows_list = flows.permute(0, 2, 3, 1).cpu().numpy() + # edges = [] + # for f in list(flows_list): + # flows_gray = (f[:, :, 0] ** 2 + f[:, :, 1] ** 2) ** 0.5 + # if flows_gray.max() < 1: + # flows_gray = flows_gray*0 + # else: + # flows_gray = flows_gray / flows_gray.max() + + # edge = canny(flows_gray, sigma=2, low_threshold=0.1, high_threshold=0.2) # fgvc + # edge = torch.from_numpy(edge).view(1, 1, h, w).float() + # edges.append(edge) + # edges = torch.stack(edges, dim=0).to(self.config['device']) + # edges = edges.view(b, t, 1, h, w) + # return edges + + def get_edges(self, flows): + # (b, t, 2, H, W) + b, t, _, h, w = flows.shape + flows = flows.view(-1, 2, h, w) + 
flows_gray = (flows[:, 0, None] ** 2 + flows[:, 1, None] ** 2) ** 0.5 + if flows_gray.max() < 1: + flows_gray = flows_gray*0 + else: + flows_gray = flows_gray / flows_gray.max() + + magnitude, edges = self.canny(flows_gray.float()) + edges = edges.view(b, t, 1, h, w) + return edges + + def _train_epoch(self, pbar): + """Process input and calculate loss every training epoch""" + device = self.config['device'] + train_data = self.prefetcher.next() + while train_data is not None: + self.iteration += 1 + frames, masks, flows_f, flows_b, _ = train_data + frames, masks = frames.to(device), masks.to(device) + masks = masks.float() + + l_t = self.num_local_frames + b, t, c, h, w = frames.size() + gt_local_frames = frames[:, :l_t, ...] + local_masks = masks[:, :l_t, ...].contiguous() + + # get gt optical flow + if flows_f[0] == 'None' or flows_b[0] == 'None': + gt_flows_bi = self.fix_raft(gt_local_frames) + else: + gt_flows_bi = (flows_f.to(device), flows_b.to(device)) + + # get gt edge + gt_edges_forward = self.get_edges(gt_flows_bi[0]) + gt_edges_backward = self.get_edges(gt_flows_bi[1]) + gt_edges_bi = [gt_edges_forward, gt_edges_backward] + + # complete flow + pred_flows_bi, pred_edges_bi = self.netG.module.forward_bidirect_flow(gt_flows_bi, local_masks) + + # optimize net_g + self.optimG.zero_grad() + + # compulte flow_loss + flow_loss, warp_loss = self.flow_loss(pred_flows_bi, gt_flows_bi, local_masks, gt_local_frames) + flow_loss = flow_loss * self.config['losses']['flow_weight'] + warp_loss = warp_loss * 0.01 + self.add_summary(self.gen_writer, 'loss/flow_loss', flow_loss.item()) + self.add_summary(self.gen_writer, 'loss/warp_loss', warp_loss.item()) + + # compute edge loss + edge_loss = self.edge_loss(pred_edges_bi, gt_edges_bi, local_masks) + edge_loss = edge_loss*1.0 + self.add_summary(self.gen_writer, 'loss/edge_loss', edge_loss.item()) + + loss = flow_loss + warp_loss + edge_loss + loss.backward() + self.optimG.step() + self.update_learning_rate() + + # write image to tensorboard + # if self.iteration % 200 == 0: + if self.iteration % 200 == 0 and self.gen_writer is not None: + t = 5 + # forward to cpu + gt_flows_forward_cpu = flow_to_image(gt_flows_bi[0][0]).cpu() + masked_flows_forward_cpu = (gt_flows_forward_cpu[t] * (1-local_masks[0][t].cpu())).to(gt_flows_forward_cpu) + pred_flows_forward_cpu = flow_to_image(pred_flows_bi[0][0]).cpu() + + flow_results = torch.cat([gt_flows_forward_cpu[t], masked_flows_forward_cpu, pred_flows_forward_cpu[t]], 1) + self.gen_writer.add_image('img/flow-f:gt-pred', flow_results, self.iteration) + + # backward to cpu + gt_flows_backward_cpu = flow_to_image(gt_flows_bi[1][0]).cpu() + masked_flows_backward_cpu = (gt_flows_backward_cpu[t] * (1-local_masks[0][t+1].cpu())).to(gt_flows_backward_cpu) + pred_flows_backward_cpu = flow_to_image(pred_flows_bi[1][0]).cpu() + + flow_results = torch.cat([gt_flows_backward_cpu[t], masked_flows_backward_cpu, pred_flows_backward_cpu[t]], 1) + self.gen_writer.add_image('img/flow-b:gt-pred', flow_results, self.iteration) + + # TODO: show edge + # forward + gt_edges_forward_cpu = gt_edges_bi[0][0].cpu() + masked_edges_forward_cpu = (gt_edges_forward_cpu[t] * (1-local_masks[0][t].cpu())).to(gt_edges_forward_cpu) + pred_edges_forward_cpu = pred_edges_bi[0][0].cpu() + + edge_results = torch.cat([gt_edges_forward_cpu[t], masked_edges_forward_cpu, pred_edges_forward_cpu[t]], 1) + self.gen_writer.add_image('img/edge-f:gt-pred', edge_results, self.iteration) + # backward + gt_edges_backward_cpu = gt_edges_bi[1][0].cpu() + 
masked_edges_backward_cpu = (gt_edges_backward_cpu[t] * (1-local_masks[0][t+1].cpu())).to(gt_edges_backward_cpu) + pred_edges_backward_cpu = pred_edges_bi[1][0].cpu() + + edge_results = torch.cat([gt_edges_backward_cpu[t], masked_edges_backward_cpu, pred_edges_backward_cpu[t]], 1) + self.gen_writer.add_image('img/edge-b:gt-pred', edge_results, self.iteration) + + # console logs + if self.config['global_rank'] == 0: + pbar.update(1) + pbar.set_description((f"flow: {flow_loss.item():.3f}; " + f"warp: {warp_loss.item():.3f}; " + f"edge: {edge_loss.item():.3f}; " + f"lr: {self.get_lr()}")) + + if self.iteration % self.train_args['log_freq'] == 0: + logging.info(f"[Iter {self.iteration}] " + f"flow: {flow_loss.item():.4f}; " + f"warp: {warp_loss.item():.4f}") + + # saving models + if self.iteration % self.train_args['save_freq'] == 0: + self.save(int(self.iteration)) + + if self.iteration > self.train_args['iterations']: + break + + train_data = self.prefetcher.next() \ No newline at end of file diff --git a/propainter/core/utils.py b/propainter/core/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..37dccb2d26e6916aacbd530ab03726a7c54f8ec8 --- /dev/null +++ b/propainter/core/utils.py @@ -0,0 +1,371 @@ +import os +import io +import cv2 +import random +import numpy as np +from PIL import Image, ImageOps +import zipfile +import math + +import torch +import matplotlib +import matplotlib.patches as patches +from matplotlib.path import Path +from matplotlib import pyplot as plt +from torchvision import transforms + +# matplotlib.use('agg') + +# ########################################################################### +# Directory IO +# ########################################################################### + + +def read_dirnames_under_root(root_dir): + dirnames = [ + name for i, name in enumerate(sorted(os.listdir(root_dir))) + if os.path.isdir(os.path.join(root_dir, name)) + ] + print(f'Reading directories under {root_dir}, num: {len(dirnames)}') + return dirnames + + +class TrainZipReader(object): + file_dict = dict() + + def __init__(self): + super(TrainZipReader, self).__init__() + + @staticmethod + def build_file_dict(path): + file_dict = TrainZipReader.file_dict + if path in file_dict: + return file_dict[path] + else: + file_handle = zipfile.ZipFile(path, 'r') + file_dict[path] = file_handle + return file_dict[path] + + @staticmethod + def imread(path, idx): + zfile = TrainZipReader.build_file_dict(path) + filelist = zfile.namelist() + filelist.sort() + data = zfile.read(filelist[idx]) + # + im = Image.open(io.BytesIO(data)) + return im + + +class TestZipReader(object): + file_dict = dict() + + def __init__(self): + super(TestZipReader, self).__init__() + + @staticmethod + def build_file_dict(path): + file_dict = TestZipReader.file_dict + if path in file_dict: + return file_dict[path] + else: + file_handle = zipfile.ZipFile(path, 'r') + file_dict[path] = file_handle + return file_dict[path] + + @staticmethod + def imread(path, idx): + zfile = TestZipReader.build_file_dict(path) + filelist = zfile.namelist() + filelist.sort() + data = zfile.read(filelist[idx]) + file_bytes = np.asarray(bytearray(data), dtype=np.uint8) + im = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR) + im = Image.fromarray(cv2.cvtColor(im, cv2.COLOR_BGR2RGB)) + # im = Image.open(io.BytesIO(data)) + return im + + +# ########################################################################### +# Data augmentation +# ########################################################################### + + 
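+# to_tensors() composes Stack and ToTorchFormatTensor: a list of T PIL frames
+# becomes a float tensor of shape (T, C, H, W) with values in [0, 1].
+# Illustrative usage (variable names here are examples only):
+#   frames = [Image.open(p).convert('RGB') for p in frame_paths]
+#   batch = to_tensors()(frames).unsqueeze(0)   # (1, T, C, H, W), range [0, 1]
+#   batch = batch * 2 - 1                       # [-1, 1], the range the inpainting models expect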
+def to_tensors(): + return transforms.Compose([Stack(), ToTorchFormatTensor()]) + + +class GroupRandomHorizontalFlowFlip(object): + """Randomly horizontally flips the given PIL.Image with a probability of 0.5 + """ + def __call__(self, img_group, flowF_group, flowB_group): + v = random.random() + if v < 0.5: + ret_img = [ + img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group + ] + ret_flowF = [ff[:, ::-1] * [-1.0, 1.0] for ff in flowF_group] + ret_flowB = [fb[:, ::-1] * [-1.0, 1.0] for fb in flowB_group] + return ret_img, ret_flowF, ret_flowB + else: + return img_group, flowF_group, flowB_group + + +class GroupRandomHorizontalFlip(object): + """Randomly horizontally flips the given PIL.Image with a probability of 0.5 + """ + def __call__(self, img_group, is_flow=False): + v = random.random() + if v < 0.5: + ret = [img.transpose(Image.FLIP_LEFT_RIGHT) for img in img_group] + if is_flow: + for i in range(0, len(ret), 2): + # invert flow pixel values when flipping + ret[i] = ImageOps.invert(ret[i]) + return ret + else: + return img_group + + +class Stack(object): + def __init__(self, roll=False): + self.roll = roll + + def __call__(self, img_group): + mode = img_group[0].mode + if mode == '1': + img_group = [img.convert('L') for img in img_group] + mode = 'L' + if mode == 'L': + return np.stack([np.expand_dims(x, 2) for x in img_group], axis=2) + elif mode == 'RGB': + if self.roll: + return np.stack([np.array(x)[:, :, ::-1] for x in img_group], + axis=2) + else: + return np.stack(img_group, axis=2) + else: + raise NotImplementedError(f"Image mode {mode}") + + +class ToTorchFormatTensor(object): + """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255] + to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """ + def __init__(self, div=True): + self.div = div + + def __call__(self, pic): + if isinstance(pic, np.ndarray): + # numpy img: [L, C, H, W] + img = torch.from_numpy(pic).permute(2, 3, 0, 1).contiguous() + else: + # handle PIL Image + img = torch.ByteTensor(torch.ByteStorage.from_buffer( + pic.tobytes())) + img = img.view(pic.size[1], pic.size[0], len(pic.mode)) + # put it from HWC to CHW format + # yikes, this transpose takes 80% of the loading time/CPU + img = img.transpose(0, 1).transpose(0, 2).contiguous() + img = img.float().div(255) if self.div else img.float() + return img + + +# ########################################################################### +# Create masks with random shape +# ########################################################################### + + +def create_random_shape_with_random_motion(video_length, + imageHeight=240, + imageWidth=432): + # get a random shape + height = random.randint(imageHeight // 3, imageHeight - 1) + width = random.randint(imageWidth // 3, imageWidth - 1) + edge_num = random.randint(6, 8) + ratio = random.randint(6, 8) / 10 + + region = get_random_shape(edge_num=edge_num, + ratio=ratio, + height=height, + width=width) + region_width, region_height = region.size + # get random position + x, y = random.randint(0, imageHeight - region_height), random.randint( + 0, imageWidth - region_width) + velocity = get_random_velocity(max_speed=3) + m = Image.fromarray(np.zeros((imageHeight, imageWidth)).astype(np.uint8)) + m.paste(region, (y, x, y + region.size[0], x + region.size[1])) + masks = [m.convert('L')] + # return fixed masks + if random.uniform(0, 1) > 0.5: + return masks * video_length + # return moving masks + for _ in range(video_length - 1): + x, y, velocity = 
random_move_control_points(x, + y, + imageHeight, + imageWidth, + velocity, + region.size, + maxLineAcceleration=(3, + 0.5), + maxInitSpeed=3) + m = Image.fromarray( + np.zeros((imageHeight, imageWidth)).astype(np.uint8)) + m.paste(region, (y, x, y + region.size[0], x + region.size[1])) + masks.append(m.convert('L')) + return masks + + +def create_random_shape_with_random_motion_zoom_rotation(video_length, zoomin=0.9, zoomout=1.1, rotmin=1, rotmax=10, imageHeight=240, imageWidth=432): + # get a random shape + assert zoomin < 1, "Zoom-in parameter must be smaller than 1" + assert zoomout > 1, "Zoom-out parameter must be larger than 1" + assert rotmin < rotmax, "Minimum value of rotation must be smaller than maximun value !" + height = random.randint(imageHeight//3, imageHeight-1) + width = random.randint(imageWidth//3, imageWidth-1) + edge_num = random.randint(6, 8) + ratio = random.randint(6, 8)/10 + region = get_random_shape( + edge_num=edge_num, ratio=ratio, height=height, width=width) + region_width, region_height = region.size + # get random position + x, y = random.randint( + 0, imageHeight-region_height), random.randint(0, imageWidth-region_width) + velocity = get_random_velocity(max_speed=3) + m = Image.fromarray(np.zeros((imageHeight, imageWidth)).astype(np.uint8)) + m.paste(region, (y, x, y+region.size[0], x+region.size[1])) + masks = [m.convert('L')] + # return fixed masks + if random.uniform(0, 1) > 0.5: + return masks*video_length # -> directly copy all the base masks + # return moving masks + for _ in range(video_length-1): + x, y, velocity = random_move_control_points( + x, y, imageHeight, imageWidth, velocity, region.size, maxLineAcceleration=(3, 0.5), maxInitSpeed=3) + m = Image.fromarray( + np.zeros((imageHeight, imageWidth)).astype(np.uint8)) + ### add by kaidong, to simulate zoon-in, zoom-out and rotation + extra_transform = random.uniform(0, 1) + # zoom in and zoom out + if extra_transform > 0.75: + resize_coefficient = random.uniform(zoomin, zoomout) + region = region.resize((math.ceil(region_width * resize_coefficient), math.ceil(region_height * resize_coefficient)), Image.NEAREST) + m.paste(region, (y, x, y + region.size[0], x + region.size[1])) + region_width, region_height = region.size + # rotation + elif extra_transform > 0.5: + m.paste(region, (y, x, y + region.size[0], x + region.size[1])) + m = m.rotate(random.randint(rotmin, rotmax)) + # region_width, region_height = region.size + ### end + else: + m.paste(region, (y, x, y+region.size[0], x+region.size[1])) + masks.append(m.convert('L')) + return masks + + +def get_random_shape(edge_num=9, ratio=0.7, width=432, height=240): + ''' + There is the initial point and 3 points per cubic bezier curve. + Thus, the curve will only pass though n points, which will be the sharp edges. + The other 2 modify the shape of the bezier curve. 
+ edge_num, Number of possibly sharp edges + points_num, number of points in the Path + ratio, (0, 1) magnitude of the perturbation from the unit circle, + ''' + points_num = edge_num*3 + 1 + angles = np.linspace(0, 2*np.pi, points_num) + codes = np.full(points_num, Path.CURVE4) + codes[0] = Path.MOVETO + # Using this instead of Path.CLOSEPOLY avoids an innecessary straight line + verts = np.stack((np.cos(angles), np.sin(angles))).T * \ + (2*ratio*np.random.random(points_num)+1-ratio)[:, None] + verts[-1, :] = verts[0, :] + path = Path(verts, codes) + # draw paths into images + fig = plt.figure() + ax = fig.add_subplot(111) + patch = patches.PathPatch(path, facecolor='black', lw=2) + ax.add_patch(patch) + ax.set_xlim(np.min(verts)*1.1, np.max(verts)*1.1) + ax.set_ylim(np.min(verts)*1.1, np.max(verts)*1.1) + ax.axis('off') # removes the axis to leave only the shape + fig.canvas.draw() + # convert plt images into numpy images + data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) + data = data.reshape((fig.canvas.get_width_height()[::-1] + (3,))) + plt.close(fig) + # postprocess + data = cv2.resize(data, (width, height))[:, :, 0] + data = (1 - np.array(data > 0).astype(np.uint8))*255 + corrdinates = np.where(data > 0) + xmin, xmax, ymin, ymax = np.min(corrdinates[0]), np.max( + corrdinates[0]), np.min(corrdinates[1]), np.max(corrdinates[1]) + region = Image.fromarray(data).crop((ymin, xmin, ymax, xmax)) + return region + + +def random_accelerate(velocity, maxAcceleration, dist='uniform'): + speed, angle = velocity + d_speed, d_angle = maxAcceleration + if dist == 'uniform': + speed += np.random.uniform(-d_speed, d_speed) + angle += np.random.uniform(-d_angle, d_angle) + elif dist == 'guassian': + speed += np.random.normal(0, d_speed / 2) + angle += np.random.normal(0, d_angle / 2) + else: + raise NotImplementedError( + f'Distribution type {dist} is not supported.') + return (speed, angle) + + +def get_random_velocity(max_speed=3, dist='uniform'): + if dist == 'uniform': + speed = np.random.uniform(max_speed) + elif dist == 'guassian': + speed = np.abs(np.random.normal(0, max_speed / 2)) + else: + raise NotImplementedError( + f'Distribution type {dist} is not supported.') + angle = np.random.uniform(0, 2 * np.pi) + return (speed, angle) + + +def random_move_control_points(X, + Y, + imageHeight, + imageWidth, + lineVelocity, + region_size, + maxLineAcceleration=(3, 0.5), + maxInitSpeed=3): + region_width, region_height = region_size + speed, angle = lineVelocity + X += int(speed * np.cos(angle)) + Y += int(speed * np.sin(angle)) + lineVelocity = random_accelerate(lineVelocity, + maxLineAcceleration, + dist='guassian') + if ((X > imageHeight - region_height) or (X < 0) + or (Y > imageWidth - region_width) or (Y < 0)): + lineVelocity = get_random_velocity(maxInitSpeed, dist='guassian') + new_X = np.clip(X, 0, imageHeight - region_height) + new_Y = np.clip(Y, 0, imageWidth - region_width) + return new_X, new_Y, lineVelocity + + +if __name__ == '__main__': + + trials = 10 + for _ in range(trials): + video_length = 10 + # The returned masks are either stationary (50%) or moving (50%) + masks = create_random_shape_with_random_motion(video_length, + imageHeight=240, + imageWidth=432) + + for m in masks: + cv2.imshow('mask', np.array(m)) + cv2.waitKey(500) diff --git a/propainter/inference.py b/propainter/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..b180d49edbcc03df8bb03a3fb8ac89d16b22855d --- /dev/null +++ b/propainter/inference.py @@ -0,0 +1,520 @@ +# 
-*- coding: utf-8 -*- +import os +import cv2 +import numpy as np +import scipy.ndimage +from PIL import Image +from tqdm import tqdm +import torch +import torchvision +import gc + +try: + from model.modules.flow_comp_raft import RAFT_bi + from model.recurrent_flow_completion import RecurrentFlowCompleteNet + from model.propainter import InpaintGenerator + from utils.download_util import load_file_from_url + from core.utils import to_tensors + from model.misc import get_device +except: + from propainter.model.modules.flow_comp_raft import RAFT_bi + from propainter.model.recurrent_flow_completion import RecurrentFlowCompleteNet + from propainter.model.propainter import InpaintGenerator + from propainter.utils.download_util import load_file_from_url + from propainter.core.utils import to_tensors + from propainter.model.misc import get_device + +import warnings +warnings.filterwarnings("ignore") + +pretrain_model_url = 'https://github.com/sczhou/ProPainter/releases/download/v0.1.0/' +MaxSideThresh = 960 + + +# resize frames +def resize_frames(frames, size=None): + if size is not None: + out_size = size + process_size = (out_size[0]-out_size[0]%8, out_size[1]-out_size[1]%8) + frames = [f.resize(process_size) for f in frames] + else: + out_size = frames[0].size + process_size = (out_size[0]-out_size[0]%8, out_size[1]-out_size[1]%8) + if not out_size == process_size: + frames = [f.resize(process_size) for f in frames] + + return frames, process_size, out_size + +# read frames from video +def read_frame_from_videos(frame_root, video_length): + if frame_root.endswith(('mp4', 'mov', 'avi', 'MP4', 'MOV', 'AVI')): # input video path + video_name = os.path.basename(frame_root)[:-4] + vframes, aframes, info = torchvision.io.read_video(filename=frame_root, pts_unit='sec', end_pts=video_length) # RGB + frames = list(vframes.numpy()) + frames = [Image.fromarray(f) for f in frames] + fps = info['video_fps'] + nframes = len(frames) + else: + video_name = os.path.basename(frame_root) + frames = [] + fr_lst = sorted(os.listdir(frame_root)) + for fr in fr_lst: + frame = cv2.imread(os.path.join(frame_root, fr)) + frame = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + frames.append(frame) + fps = None + nframes = len(frames) + size = frames[0].size + + return frames, fps, size, video_name, nframes + +def binary_mask(mask, th=0.1): + mask[mask>th] = 1 + mask[mask<=th] = 0 + return mask + +# read frame-wise masks +def read_mask(mpath, frames_len, size, flow_mask_dilates=8, mask_dilates=5): + masks_img = [] + masks_dilated = [] + flow_masks = [] + + if mpath.endswith(('jpg', 'jpeg', 'png', 'JPG', 'JPEG', 'PNG')): # input single img path + masks_img = [Image.open(mpath)] + elif mpath.endswith(('mp4', 'mov', 'avi', 'MP4', 'MOV', 'AVI')): # input video path + cap = cv2.VideoCapture(mpath) + if not cap.isOpened(): + print("Error: Could not open video.") + exit() + idx = 0 + while True: + ret, frame = cap.read() + if not ret: + break + if(idx >= frames_len): + break + masks_img.append(Image.fromarray(frame)) + idx += 1 + cap.release() + else: + mnames = sorted(os.listdir(mpath)) + for mp in mnames: + masks_img.append(Image.open(os.path.join(mpath, mp))) + # print(mp) + + for mask_img in masks_img: + if size is not None: + mask_img = mask_img.resize(size, Image.NEAREST) + mask_img = np.array(mask_img.convert('L')) + + # Dilate 8 pixel so that all known pixel is trustworthy + if flow_mask_dilates > 0: + flow_mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=flow_mask_dilates).astype(np.uint8) + else: 
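+ # no dilation requested: keep the raw mask, just binarized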
+ flow_mask_img = binary_mask(mask_img).astype(np.uint8) + # Close the small holes inside the foreground objects + # flow_mask_img = cv2.morphologyEx(flow_mask_img, cv2.MORPH_CLOSE, np.ones((21, 21),np.uint8)).astype(bool) + # flow_mask_img = scipy.ndimage.binary_fill_holes(flow_mask_img).astype(np.uint8) + flow_masks.append(Image.fromarray(flow_mask_img * 255)) + + if mask_dilates > 0: + mask_img = scipy.ndimage.binary_dilation(mask_img, iterations=mask_dilates).astype(np.uint8) + else: + mask_img = binary_mask(mask_img).astype(np.uint8) + masks_dilated.append(Image.fromarray(mask_img * 255)) + + if len(masks_img) == 1: + flow_masks = flow_masks * frames_len + masks_dilated = masks_dilated * frames_len + + return flow_masks, masks_dilated + +def get_ref_index(mid_neighbor_id, neighbor_ids, length, ref_stride=10, ref_num=-1): + ref_index = [] + if ref_num == -1: + for i in range(0, length, ref_stride): + if i not in neighbor_ids: + ref_index.append(i) + else: + start_idx = max(0, mid_neighbor_id - ref_stride * (ref_num // 2)) + end_idx = min(length, mid_neighbor_id + ref_stride * (ref_num // 2)) + for i in range(start_idx, end_idx, ref_stride): + if i not in neighbor_ids: + if len(ref_index) > ref_num: + break + ref_index.append(i) + return ref_index + + +class Propainter: + def __init__( + self, propainter_model_dir, device): + self.device = device + ############################################## + # set up RAFT and flow competition model + ############################################## + ckpt_path = load_file_from_url(url=os.path.join(pretrain_model_url, 'raft-things.pth'), + model_dir=propainter_model_dir, progress=True, file_name=None) + self.fix_raft = RAFT_bi(ckpt_path, device) + + ckpt_path = load_file_from_url(url=os.path.join(pretrain_model_url, 'recurrent_flow_completion.pth'), + model_dir=propainter_model_dir, progress=True, file_name=None) + self.fix_flow_complete = RecurrentFlowCompleteNet(ckpt_path) + for p in self.fix_flow_complete.parameters(): + p.requires_grad = False + self.fix_flow_complete.to(device) + self.fix_flow_complete.eval() + + ############################################## + # set up ProPainter model + ############################################## + ckpt_path = load_file_from_url(url=os.path.join(pretrain_model_url, 'ProPainter.pth'), + model_dir=propainter_model_dir, progress=True, file_name=None) + self.model = InpaintGenerator(model_path=ckpt_path).to(device) + self.model.eval() + def forward(self, video, mask, output_path, resize_ratio=0.6, video_length=2, height=-1, width=-1, + mask_dilation=4, ref_stride=10, neighbor_length=10, subvideo_length=80, + raft_iter=20, save_fps=24, save_frames=False, fp16=True): + + # Use fp16 precision during inference to reduce running memory cost + use_half = True if fp16 else False + if self.device == torch.device('cpu'): + use_half = False + + ################ read input video ################ + frames, fps, size, video_name, nframes = read_frame_from_videos(video, video_length) + frames = frames[:nframes] + if not width == -1 and not height == -1: + size = (width, height) + + longer_edge = max(size[0], size[1]) + if(longer_edge > MaxSideThresh): + scale = MaxSideThresh / longer_edge + resize_ratio = resize_ratio * scale + if not resize_ratio == 1.0: + size = (int(resize_ratio * size[0]), int(resize_ratio * size[1])) + + frames, size, out_size = resize_frames(frames, size) + fps = save_fps if fps is None else fps + + ################ read mask ################ + frames_len = len(frames) + flow_masks, masks_dilated = 
read_mask(mask, frames_len, size, + flow_mask_dilates=mask_dilation, + mask_dilates=mask_dilation) + flow_masks = flow_masks[:nframes] + masks_dilated = masks_dilated[:nframes] + w, h = size + + ################ adjust input ################ + frames_len = min(len(frames), len(masks_dilated)) + frames = frames[:frames_len] + flow_masks = flow_masks[:frames_len] + masks_dilated = masks_dilated[:frames_len] + + ori_frames_inp = [np.array(f).astype(np.uint8) for f in frames] + frames = to_tensors()(frames).unsqueeze(0) * 2 - 1 + flow_masks = to_tensors()(flow_masks).unsqueeze(0) + masks_dilated = to_tensors()(masks_dilated).unsqueeze(0) + frames, flow_masks, masks_dilated = frames.to(self.device), flow_masks.to(self.device), masks_dilated.to(self.device) + + ############################################## + # ProPainter inference + ############################################## + video_length = frames.size(1) + print(f'Priori generating: [{video_length} frames]...') + with torch.no_grad(): + # ---- compute flow ---- + new_longer_edge = max(frames.size(-1), frames.size(-2)) + if new_longer_edge <= 640: + short_clip_len = 12 + elif new_longer_edge <= 720: + short_clip_len = 8 + elif new_longer_edge <= 1280: + short_clip_len = 4 + else: + short_clip_len = 2 + + # use fp32 for RAFT + if frames.size(1) > short_clip_len: + gt_flows_f_list, gt_flows_b_list = [], [] + for f in range(0, video_length, short_clip_len): + end_f = min(video_length, f + short_clip_len) + if f == 0: + flows_f, flows_b = self.fix_raft(frames[:,f:end_f], iters=raft_iter) + else: + flows_f, flows_b = self.fix_raft(frames[:,f-1:end_f], iters=raft_iter) + + gt_flows_f_list.append(flows_f) + gt_flows_b_list.append(flows_b) + torch.cuda.empty_cache() + + gt_flows_f = torch.cat(gt_flows_f_list, dim=1) + gt_flows_b = torch.cat(gt_flows_b_list, dim=1) + gt_flows_bi = (gt_flows_f, gt_flows_b) + else: + gt_flows_bi = self.fix_raft(frames, iters=raft_iter) + torch.cuda.empty_cache() + torch.cuda.empty_cache() + gc.collect() + + if use_half: + frames, flow_masks, masks_dilated = frames.half(), flow_masks.half(), masks_dilated.half() + gt_flows_bi = (gt_flows_bi[0].half(), gt_flows_bi[1].half()) + self.fix_flow_complete = self.fix_flow_complete.half() + self.model = self.model.half() + + # ---- complete flow ---- + flow_length = gt_flows_bi[0].size(1) + if flow_length > subvideo_length: + pred_flows_f, pred_flows_b = [], [] + pad_len = 5 + for f in range(0, flow_length, subvideo_length): + s_f = max(0, f - pad_len) + e_f = min(flow_length, f + subvideo_length + pad_len) + pad_len_s = max(0, f) - s_f + pad_len_e = e_f - min(flow_length, f + subvideo_length) + pred_flows_bi_sub, _ = self.fix_flow_complete.forward_bidirect_flow( + (gt_flows_bi[0][:, s_f:e_f], gt_flows_bi[1][:, s_f:e_f]), + flow_masks[:, s_f:e_f+1]) + pred_flows_bi_sub = self.fix_flow_complete.combine_flow( + (gt_flows_bi[0][:, s_f:e_f], gt_flows_bi[1][:, s_f:e_f]), + pred_flows_bi_sub, + flow_masks[:, s_f:e_f+1]) + + pred_flows_f.append(pred_flows_bi_sub[0][:, pad_len_s:e_f-s_f-pad_len_e]) + pred_flows_b.append(pred_flows_bi_sub[1][:, pad_len_s:e_f-s_f-pad_len_e]) + torch.cuda.empty_cache() + + pred_flows_f = torch.cat(pred_flows_f, dim=1) + pred_flows_b = torch.cat(pred_flows_b, dim=1) + pred_flows_bi = (pred_flows_f, pred_flows_b) + else: + pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow(gt_flows_bi, flow_masks) + pred_flows_bi = self.fix_flow_complete.combine_flow(gt_flows_bi, pred_flows_bi, flow_masks) + torch.cuda.empty_cache() + torch.cuda.empty_cache() 
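+ # Free cached GPU memory before the propagation stages. The pre-propagation
+ # pass below re-estimates and completes flows on a temporally subsampled set
+ # of frames, propagates pixels across them, and writes the results back into
+ # frames / masks_dilated so that long-range content can reach the final
+ # frame-by-frame pass.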
+ gc.collect() + + + masks_dilated_ori = masks_dilated.clone() + # ---- Pre-propagation ---- + subvideo_length_img_prop = min(100, subvideo_length) # ensure a minimum of 100 frames for image propagation + if(len(frames[0]))>subvideo_length_img_prop: # perform propagation only when length of frames is larger than subvideo_length_img_prop + sample_rate = len(frames[0])//(subvideo_length_img_prop//2) + index_sample = list(range(0, len(frames[0]), sample_rate)) + sample_frames = torch.stack([frames[0][i].to(torch.float32) for i in index_sample]).unsqueeze(0) # use fp32 for RAFT + sample_masks_dilated = torch.stack([masks_dilated[0][i] for i in index_sample]).unsqueeze(0) + sample_flow_masks = torch.stack([flow_masks[0][i] for i in index_sample]).unsqueeze(0) + + ## recompute flow for sampled frames + # use fp32 for RAFT + sample_video_length = sample_frames.size(1) + if sample_frames.size(1) > short_clip_len: + gt_flows_f_list, gt_flows_b_list = [], [] + for f in range(0, sample_video_length, short_clip_len): + end_f = min(sample_video_length, f + short_clip_len) + if f == 0: + flows_f, flows_b = self.fix_raft(sample_frames[:,f:end_f], iters=raft_iter) + else: + flows_f, flows_b = self.fix_raft(sample_frames[:,f-1:end_f], iters=raft_iter) + + gt_flows_f_list.append(flows_f) + gt_flows_b_list.append(flows_b) + torch.cuda.empty_cache() + + gt_flows_f = torch.cat(gt_flows_f_list, dim=1) + gt_flows_b = torch.cat(gt_flows_b_list, dim=1) + sample_gt_flows_bi = (gt_flows_f, gt_flows_b) + else: + sample_gt_flows_bi = self.fix_raft(sample_frames, iters=raft_iter) + torch.cuda.empty_cache() + torch.cuda.empty_cache() + gc.collect() + + if use_half: + sample_frames, sample_flow_masks, sample_masks_dilated = sample_frames.half(), sample_flow_masks.half(), sample_masks_dilated.half() + sample_gt_flows_bi = (sample_gt_flows_bi[0].half(), sample_gt_flows_bi[1].half()) + + # ---- complete flow ---- + flow_length = sample_gt_flows_bi[0].size(1) + if flow_length > subvideo_length: + pred_flows_f, pred_flows_b = [], [] + pad_len = 5 + for f in range(0, flow_length, subvideo_length): + s_f = max(0, f - pad_len) + e_f = min(flow_length, f + subvideo_length + pad_len) + pad_len_s = max(0, f) - s_f + pad_len_e = e_f - min(flow_length, f + subvideo_length) + pred_flows_bi_sub, _ = self.fix_flow_complete.forward_bidirect_flow( + (sample_gt_flows_bi[0][:, s_f:e_f], sample_gt_flows_bi[1][:, s_f:e_f]), + sample_flow_masks[:, s_f:e_f+1]) + pred_flows_bi_sub = self.fix_flow_complete.combine_flow( + (sample_gt_flows_bi[0][:, s_f:e_f], sample_gt_flows_bi[1][:, s_f:e_f]), + pred_flows_bi_sub, + sample_flow_masks[:, s_f:e_f+1]) + + pred_flows_f.append(pred_flows_bi_sub[0][:, pad_len_s:e_f-s_f-pad_len_e]) + pred_flows_b.append(pred_flows_bi_sub[1][:, pad_len_s:e_f-s_f-pad_len_e]) + torch.cuda.empty_cache() + + pred_flows_f = torch.cat(pred_flows_f, dim=1) + pred_flows_b = torch.cat(pred_flows_b, dim=1) + sample_pred_flows_bi = (pred_flows_f, pred_flows_b) + else: + sample_pred_flows_bi, _ = self.fix_flow_complete.forward_bidirect_flow(sample_gt_flows_bi, sample_flow_masks) + sample_pred_flows_bi = self.fix_flow_complete.combine_flow(sample_gt_flows_bi, sample_pred_flows_bi, sample_flow_masks) + torch.cuda.empty_cache() + torch.cuda.empty_cache() + gc.collect() + + masked_frames = sample_frames * (1 - sample_masks_dilated) + + if sample_video_length > subvideo_length_img_prop: + updated_frames, updated_masks = [], [] + pad_len = 10 + for f in range(0, sample_video_length, subvideo_length_img_prop): + s_f = max(0, f - pad_len) + 
e_f = min(sample_video_length, f + subvideo_length_img_prop + pad_len) + pad_len_s = max(0, f) - s_f + pad_len_e = e_f - min(sample_video_length, f + subvideo_length_img_prop) + + b, t, _, _, _ = sample_masks_dilated[:, s_f:e_f].size() + pred_flows_bi_sub = (sample_pred_flows_bi[0][:, s_f:e_f-1], sample_pred_flows_bi[1][:, s_f:e_f-1]) + prop_imgs_sub, updated_local_masks_sub = self.model.img_propagation(masked_frames[:, s_f:e_f], + pred_flows_bi_sub, + sample_masks_dilated[:, s_f:e_f], + 'nearest') + updated_frames_sub = sample_frames[:, s_f:e_f] * (1 - sample_masks_dilated[:, s_f:e_f]) + \ + prop_imgs_sub.view(b, t, 3, h, w) * sample_masks_dilated[:, s_f:e_f] + updated_masks_sub = updated_local_masks_sub.view(b, t, 1, h, w) + + updated_frames.append(updated_frames_sub[:, pad_len_s:e_f-s_f-pad_len_e]) + updated_masks.append(updated_masks_sub[:, pad_len_s:e_f-s_f-pad_len_e]) + torch.cuda.empty_cache() + + updated_frames = torch.cat(updated_frames, dim=1) + updated_masks = torch.cat(updated_masks, dim=1) + else: + b, t, _, _, _ = sample_masks_dilated.size() + prop_imgs, updated_local_masks = self.model.img_propagation(masked_frames, sample_pred_flows_bi, sample_masks_dilated, 'nearest') + updated_frames = sample_frames * (1 - sample_masks_dilated) + prop_imgs.view(b, t, 3, h, w) * sample_masks_dilated + updated_masks = updated_local_masks.view(b, t, 1, h, w) + torch.cuda.empty_cache() + + ## replace input frames/masks with updated frames/masks + for i,index in enumerate(index_sample): + frames[0][index] = updated_frames[0][i] + masks_dilated[0][index] = updated_masks[0][i] + + + # ---- frame-by-frame image propagation ---- + masked_frames = frames * (1 - masks_dilated) + subvideo_length_img_prop = min(100, subvideo_length) # ensure a minimum of 100 frames for image propagation + if video_length > subvideo_length_img_prop: + updated_frames, updated_masks = [], [] + pad_len = 10 + for f in range(0, video_length, subvideo_length_img_prop): + s_f = max(0, f - pad_len) + e_f = min(video_length, f + subvideo_length_img_prop + pad_len) + pad_len_s = max(0, f) - s_f + pad_len_e = e_f - min(video_length, f + subvideo_length_img_prop) + + b, t, _, _, _ = masks_dilated[:, s_f:e_f].size() + pred_flows_bi_sub = (pred_flows_bi[0][:, s_f:e_f-1], pred_flows_bi[1][:, s_f:e_f-1]) + prop_imgs_sub, updated_local_masks_sub = self.model.img_propagation(masked_frames[:, s_f:e_f], + pred_flows_bi_sub, + masks_dilated[:, s_f:e_f], + 'nearest') + updated_frames_sub = frames[:, s_f:e_f] * (1 - masks_dilated[:, s_f:e_f]) + \ + prop_imgs_sub.view(b, t, 3, h, w) * masks_dilated[:, s_f:e_f] + updated_masks_sub = updated_local_masks_sub.view(b, t, 1, h, w) + + updated_frames.append(updated_frames_sub[:, pad_len_s:e_f-s_f-pad_len_e]) + updated_masks.append(updated_masks_sub[:, pad_len_s:e_f-s_f-pad_len_e]) + torch.cuda.empty_cache() + + updated_frames = torch.cat(updated_frames, dim=1) + updated_masks = torch.cat(updated_masks, dim=1) + else: + b, t, _, _, _ = masks_dilated.size() + prop_imgs, updated_local_masks = self.model.img_propagation(masked_frames, pred_flows_bi, masks_dilated, 'nearest') + updated_frames = frames * (1 - masks_dilated) + prop_imgs.view(b, t, 3, h, w) * masks_dilated + updated_masks = updated_local_masks.view(b, t, 1, h, w) + torch.cuda.empty_cache() + + comp_frames = [None] * video_length + + neighbor_stride = neighbor_length // 2 + if video_length > subvideo_length: + ref_num = subvideo_length // ref_stride + else: + ref_num = -1 + + torch.cuda.empty_cache() + # ---- feature propagation + 
transformer ---- + for f in tqdm(range(0, video_length, neighbor_stride)): + neighbor_ids = [ + i for i in range(max(0, f - neighbor_stride), + min(video_length, f + neighbor_stride + 1)) + ] + ref_ids = get_ref_index(f, neighbor_ids, video_length, ref_stride, ref_num) + selected_imgs = updated_frames[:, neighbor_ids + ref_ids, :, :, :] + selected_masks = masks_dilated[:, neighbor_ids + ref_ids, :, :, :] + selected_update_masks = updated_masks[:, neighbor_ids + ref_ids, :, :, :] + selected_pred_flows_bi = (pred_flows_bi[0][:, neighbor_ids[:-1], :, :, :], pred_flows_bi[1][:, neighbor_ids[:-1], :, :, :]) + + with torch.no_grad(): + # 1.0 indicates mask + l_t = len(neighbor_ids) + + # pred_img = selected_imgs # results of image propagation + pred_img = self.model(selected_imgs, selected_pred_flows_bi, selected_masks, selected_update_masks, l_t) + pred_img = pred_img.view(-1, 3, h, w) + + ## compose with input frames + pred_img = (pred_img + 1) / 2 + pred_img = pred_img.cpu().permute(0, 2, 3, 1).numpy() * 255 + binary_masks = masks_dilated_ori[0, neighbor_ids, :, :, :].cpu().permute( + 0, 2, 3, 1).numpy().astype(np.uint8) # use original mask + for i in range(len(neighbor_ids)): + idx = neighbor_ids[i] + img = np.array(pred_img[i]).astype(np.uint8) * binary_masks[i] \ + + ori_frames_inp[idx] * (1 - binary_masks[i]) + if comp_frames[idx] is None: + comp_frames[idx] = img + else: + comp_frames[idx] = comp_frames[idx].astype(np.float32) * 0.5 + img.astype(np.float32) * 0.5 + + comp_frames[idx] = comp_frames[idx].astype(np.uint8) + + torch.cuda.empty_cache() + + ##save composed video## + comp_frames = [cv2.resize(f, out_size) for f in comp_frames] + writer = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*"mp4v"), + fps, (comp_frames[0].shape[1],comp_frames[0].shape[0])) + for f in range(video_length): + frame = comp_frames[f].astype(np.uint8) + writer.write(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)) + writer.release() + + torch.cuda.empty_cache() + + return output_path + + + +if __name__ == '__main__': + + device = get_device() + propainter_model_dir = "weights/propainter" + propainter = Propainter(propainter_model_dir, device=device) + + video = "examples/example1/video.mp4" + mask = "examples/example1/mask.mp4" + output = "results/priori.mp4" + res = propainter.forward(video, mask, output) + + + \ No newline at end of file diff --git a/propainter/model/__init__.py b/propainter/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc --- /dev/null +++ b/propainter/model/__init__.py @@ -0,0 +1 @@ + diff --git a/propainter/model/canny/canny_filter.py b/propainter/model/canny/canny_filter.py new file mode 100644 index 0000000000000000000000000000000000000000..3d16195c9355b506e22a2ba527006adb9c541a7c --- /dev/null +++ b/propainter/model/canny/canny_filter.py @@ -0,0 +1,256 @@ +import math +from typing import Tuple + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .gaussian import gaussian_blur2d +from .kernels import get_canny_nms_kernel, get_hysteresis_kernel +from .sobel import spatial_gradient + +def rgb_to_grayscale(image, rgb_weights = None): + if len(image.shape) < 3 or image.shape[-3] != 3: + raise ValueError(f"Input size must have a shape of (*, 3, H, W). 
Got {image.shape}") + + if rgb_weights is None: + # 8 bit images + if image.dtype == torch.uint8: + rgb_weights = torch.tensor([76, 150, 29], device=image.device, dtype=torch.uint8) + # floating point images + elif image.dtype in (torch.float16, torch.float32, torch.float64): + rgb_weights = torch.tensor([0.299, 0.587, 0.114], device=image.device, dtype=image.dtype) + else: + raise TypeError(f"Unknown data type: {image.dtype}") + else: + # is tensor that we make sure is in the same device/dtype + rgb_weights = rgb_weights.to(image) + + # unpack the color image channels with RGB order + r = image[..., 0:1, :, :] + g = image[..., 1:2, :, :] + b = image[..., 2:3, :, :] + + w_r, w_g, w_b = rgb_weights.unbind() + return w_r * r + w_g * g + w_b * b + + +def canny( + input: torch.Tensor, + low_threshold: float = 0.1, + high_threshold: float = 0.2, + kernel_size: Tuple[int, int] = (5, 5), + sigma: Tuple[float, float] = (1, 1), + hysteresis: bool = True, + eps: float = 1e-6, +) -> Tuple[torch.Tensor, torch.Tensor]: + r"""Find edges of the input image and filters them using the Canny algorithm. + + .. image:: _static/img/canny.png + + Args: + input: input image tensor with shape :math:`(B,C,H,W)`. + low_threshold: lower threshold for the hysteresis procedure. + high_threshold: upper threshold for the hysteresis procedure. + kernel_size: the size of the kernel for the gaussian blur. + sigma: the standard deviation of the kernel for the gaussian blur. + hysteresis: if True, applies the hysteresis edge tracking. + Otherwise, the edges are divided between weak (0.5) and strong (1) edges. + eps: regularization number to avoid NaN during backprop. + + Returns: + - the canny edge magnitudes map, shape of :math:`(B,1,H,W)`. + - the canny edge detection filtered by thresholds and hysteresis, shape of :math:`(B,1,H,W)`. + + .. note:: + See a working example `here `__. + + Example: + >>> input = torch.rand(5, 3, 4, 4) + >>> magnitude, edges = canny(input) # 5x3x4x4 + >>> magnitude.shape + torch.Size([5, 1, 4, 4]) + >>> edges.shape + torch.Size([5, 1, 4, 4]) + """ + if not isinstance(input, torch.Tensor): + raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}") + + if not len(input.shape) == 4: + raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}") + + if low_threshold > high_threshold: + raise ValueError( + "Invalid input thresholds. low_threshold should be smaller than the high_threshold. Got: {}>{}".format( + low_threshold, high_threshold + ) + ) + + if low_threshold < 0 and low_threshold > 1: + raise ValueError(f"Invalid input threshold. low_threshold should be in range (0,1). Got: {low_threshold}") + + if high_threshold < 0 and high_threshold > 1: + raise ValueError(f"Invalid input threshold. high_threshold should be in range (0,1). 
Got: {high_threshold}") + + device: torch.device = input.device + dtype: torch.dtype = input.dtype + + # To Grayscale + if input.shape[1] == 3: + input = rgb_to_grayscale(input) + + # Gaussian filter + blurred: torch.Tensor = gaussian_blur2d(input, kernel_size, sigma) + + # Compute the gradients + gradients: torch.Tensor = spatial_gradient(blurred, normalized=False) + + # Unpack the edges + gx: torch.Tensor = gradients[:, :, 0] + gy: torch.Tensor = gradients[:, :, 1] + + # Compute gradient magnitude and angle + magnitude: torch.Tensor = torch.sqrt(gx * gx + gy * gy + eps) + angle: torch.Tensor = torch.atan2(gy, gx) + + # Radians to Degrees + angle = 180.0 * angle / math.pi + + # Round angle to the nearest 45 degree + angle = torch.round(angle / 45) * 45 + + # Non-maximal suppression + nms_kernels: torch.Tensor = get_canny_nms_kernel(device, dtype) + nms_magnitude: torch.Tensor = F.conv2d(magnitude, nms_kernels, padding=nms_kernels.shape[-1] // 2) + + # Get the indices for both directions + positive_idx: torch.Tensor = (angle / 45) % 8 + positive_idx = positive_idx.long() + + negative_idx: torch.Tensor = ((angle / 45) + 4) % 8 + negative_idx = negative_idx.long() + + # Apply the non-maximum suppression to the different directions + channel_select_filtered_positive: torch.Tensor = torch.gather(nms_magnitude, 1, positive_idx) + channel_select_filtered_negative: torch.Tensor = torch.gather(nms_magnitude, 1, negative_idx) + + channel_select_filtered: torch.Tensor = torch.stack( + [channel_select_filtered_positive, channel_select_filtered_negative], 1 + ) + + is_max: torch.Tensor = channel_select_filtered.min(dim=1)[0] > 0.0 + + magnitude = magnitude * is_max + + # Threshold + edges: torch.Tensor = F.threshold(magnitude, low_threshold, 0.0) + + low: torch.Tensor = magnitude > low_threshold + high: torch.Tensor = magnitude > high_threshold + + edges = low * 0.5 + high * 0.5 + edges = edges.to(dtype) + + # Hysteresis + if hysteresis: + edges_old: torch.Tensor = -torch.ones(edges.shape, device=edges.device, dtype=dtype) + hysteresis_kernels: torch.Tensor = get_hysteresis_kernel(device, dtype) + + while ((edges_old - edges).abs() != 0).any(): + weak: torch.Tensor = (edges == 0.5).float() + strong: torch.Tensor = (edges == 1).float() + + hysteresis_magnitude: torch.Tensor = F.conv2d( + edges, hysteresis_kernels, padding=hysteresis_kernels.shape[-1] // 2 + ) + hysteresis_magnitude = (hysteresis_magnitude == 1).any(1, keepdim=True).to(dtype) + hysteresis_magnitude = hysteresis_magnitude * weak + strong + + edges_old = edges.clone() + edges = hysteresis_magnitude + (hysteresis_magnitude == 0) * weak * 0.5 + + edges = hysteresis_magnitude + + return magnitude, edges + + +class Canny(nn.Module): + r"""Module that finds edges of the input image and filters them using the Canny algorithm. + + Args: + input: input image tensor with shape :math:`(B,C,H,W)`. + low_threshold: lower threshold for the hysteresis procedure. + high_threshold: upper threshold for the hysteresis procedure. + kernel_size: the size of the kernel for the gaussian blur. + sigma: the standard deviation of the kernel for the gaussian blur. + hysteresis: if True, applies the hysteresis edge tracking. + Otherwise, the edges are divided between weak (0.5) and strong (1) edges. + eps: regularization number to avoid NaN during backprop. + + Returns: + - the canny edge magnitudes map, shape of :math:`(B,1,H,W)`. + - the canny edge detection filtered by thresholds and hysteresis, shape of :math:`(B,1,H,W)`. 
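+    .. note:: This module only stores the parameters; its ``forward`` delegates to the functional :func:`canny` defined above.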
+ + Example: + >>> input = torch.rand(5, 3, 4, 4) + >>> magnitude, edges = Canny()(input) # 5x3x4x4 + >>> magnitude.shape + torch.Size([5, 1, 4, 4]) + >>> edges.shape + torch.Size([5, 1, 4, 4]) + """ + + def __init__( + self, + low_threshold: float = 0.1, + high_threshold: float = 0.2, + kernel_size: Tuple[int, int] = (5, 5), + sigma: Tuple[float, float] = (1, 1), + hysteresis: bool = True, + eps: float = 1e-6, + ) -> None: + super().__init__() + + if low_threshold > high_threshold: + raise ValueError( + "Invalid input thresholds. low_threshold should be\ + smaller than the high_threshold. Got: {}>{}".format( + low_threshold, high_threshold + ) + ) + + if low_threshold < 0 or low_threshold > 1: + raise ValueError(f"Invalid input threshold. low_threshold should be in range (0,1). Got: {low_threshold}") + + if high_threshold < 0 or high_threshold > 1: + raise ValueError(f"Invalid input threshold. high_threshold should be in range (0,1). Got: {high_threshold}") + + # Gaussian blur parameters + self.kernel_size = kernel_size + self.sigma = sigma + + # Double threshold + self.low_threshold = low_threshold + self.high_threshold = high_threshold + + # Hysteresis + self.hysteresis = hysteresis + + self.eps: float = eps + + def __repr__(self) -> str: + return ''.join( + ( + f'{type(self).__name__}(', + ', '.join( + f'{name}={getattr(self, name)}' for name in sorted(self.__dict__) if not name.startswith('_') + ), + ')', + ) + ) + + def forward(self, input: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return canny( + input, self.low_threshold, self.high_threshold, self.kernel_size, self.sigma, self.hysteresis, self.eps + ) \ No newline at end of file diff --git a/propainter/model/canny/filter.py b/propainter/model/canny/filter.py new file mode 100644 index 0000000000000000000000000000000000000000..e39d44d67a067c56f994dc9a189f3cf98663bf68 --- /dev/null +++ b/propainter/model/canny/filter.py @@ -0,0 +1,288 @@ +from typing import List + +import torch +import torch.nn.functional as F + +from .kernels import normalize_kernel2d + + +def _compute_padding(kernel_size: List[int]) -> List[int]: + """Compute padding tuple.""" + # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom) + # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad + if len(kernel_size) < 2: + raise AssertionError(kernel_size) + computed = [k - 1 for k in kernel_size] + + # for even kernels we need to do asymmetric padding :( + out_padding = 2 * len(kernel_size) * [0] + + for i in range(len(kernel_size)): + computed_tmp = computed[-(i + 1)] + + pad_front = computed_tmp // 2 + pad_rear = computed_tmp - pad_front + + out_padding[2 * i + 0] = pad_front + out_padding[2 * i + 1] = pad_rear + + return out_padding + + +def filter2d( + input: torch.Tensor, + kernel: torch.Tensor, + border_type: str = 'reflect', + normalized: bool = False, + padding: str = 'same', +) -> torch.Tensor: + r"""Convolve a tensor with a 2d kernel. + + The function applies a given kernel to a tensor. The kernel is applied + independently at each depth channel of the tensor. Before applying the + kernel, the function applies padding according to the specified mode so + that the output remains in the same shape. + + Args: + input: the input tensor with shape of + :math:`(B, C, H, W)`. + kernel: the kernel to be convolved with the input + tensor. The kernel shape must be :math:`(1, kH, kW)` or :math:`(B, kH, kW)`. + border_type: the padding mode to be applied before convolving. 
+ The expected modes are: ``'constant'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. + normalized: If True, kernel will be L1 normalized. + padding: This defines the type of padding. + 2 modes available ``'same'`` or ``'valid'``. + + Return: + torch.Tensor: the convolved tensor of same size and numbers of channels + as the input with shape :math:`(B, C, H, W)`. + + Example: + >>> input = torch.tensor([[[ + ... [0., 0., 0., 0., 0.], + ... [0., 0., 0., 0., 0.], + ... [0., 0., 5., 0., 0.], + ... [0., 0., 0., 0., 0.], + ... [0., 0., 0., 0., 0.],]]]) + >>> kernel = torch.ones(1, 3, 3) + >>> filter2d(input, kernel, padding='same') + tensor([[[[0., 0., 0., 0., 0.], + [0., 5., 5., 5., 0.], + [0., 5., 5., 5., 0.], + [0., 5., 5., 5., 0.], + [0., 0., 0., 0., 0.]]]]) + """ + if not isinstance(input, torch.Tensor): + raise TypeError(f"Input input is not torch.Tensor. Got {type(input)}") + + if not isinstance(kernel, torch.Tensor): + raise TypeError(f"Input kernel is not torch.Tensor. Got {type(kernel)}") + + if not isinstance(border_type, str): + raise TypeError(f"Input border_type is not string. Got {type(border_type)}") + + if border_type not in ['constant', 'reflect', 'replicate', 'circular']: + raise ValueError( + f"Invalid border type, we expect 'constant', \ + 'reflect', 'replicate', 'circular'. Got:{border_type}" + ) + + if not isinstance(padding, str): + raise TypeError(f"Input padding is not string. Got {type(padding)}") + + if padding not in ['valid', 'same']: + raise ValueError(f"Invalid padding mode, we expect 'valid' or 'same'. Got: {padding}") + + if not len(input.shape) == 4: + raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}") + + if (not len(kernel.shape) == 3) and not ((kernel.shape[0] == 0) or (kernel.shape[0] == input.shape[0])): + raise ValueError(f"Invalid kernel shape, we expect 1xHxW or BxHxW. Got: {kernel.shape}") + + # prepare kernel + b, c, h, w = input.shape + tmp_kernel: torch.Tensor = kernel.unsqueeze(1).to(input) + + if normalized: + tmp_kernel = normalize_kernel2d(tmp_kernel) + + tmp_kernel = tmp_kernel.expand(-1, c, -1, -1) + + height, width = tmp_kernel.shape[-2:] + + # pad the input tensor + if padding == 'same': + padding_shape: List[int] = _compute_padding([height, width]) + input = F.pad(input, padding_shape, mode=border_type) + + # kernel and input tensor reshape to align element-wise or batch-wise params + tmp_kernel = tmp_kernel.reshape(-1, 1, height, width) + input = input.view(-1, tmp_kernel.size(0), input.size(-2), input.size(-1)) + + # convolve the tensor with the kernel. + output = F.conv2d(input, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1) + + if padding == 'same': + out = output.view(b, c, h, w) + else: + out = output.view(b, c, h - height + 1, w - width + 1) + + return out + + +def filter2d_separable( + input: torch.Tensor, + kernel_x: torch.Tensor, + kernel_y: torch.Tensor, + border_type: str = 'reflect', + normalized: bool = False, + padding: str = 'same', +) -> torch.Tensor: + r"""Convolve a tensor with two 1d kernels, in x and y directions. + + The function applies a given kernel to a tensor. The kernel is applied + independently at each depth channel of the tensor. Before applying the + kernel, the function applies padding according to the specified mode so + that the output remains in the same shape. + + Args: + input: the input tensor with shape of + :math:`(B, C, H, W)`. + kernel_x: the kernel to be convolved with the input + tensor. The kernel shape must be :math:`(1, kW)` or :math:`(B, kW)`. 
+ kernel_y: the kernel to be convolved with the input + tensor. The kernel shape must be :math:`(1, kH)` or :math:`(B, kH)`. + border_type: the padding mode to be applied before convolving. + The expected modes are: ``'constant'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. + normalized: If True, kernel will be L1 normalized. + padding: This defines the type of padding. + 2 modes available ``'same'`` or ``'valid'``. + + Return: + torch.Tensor: the convolved tensor of same size and numbers of channels + as the input with shape :math:`(B, C, H, W)`. + + Example: + >>> input = torch.tensor([[[ + ... [0., 0., 0., 0., 0.], + ... [0., 0., 0., 0., 0.], + ... [0., 0., 5., 0., 0.], + ... [0., 0., 0., 0., 0.], + ... [0., 0., 0., 0., 0.],]]]) + >>> kernel = torch.ones(1, 3) + + >>> filter2d_separable(input, kernel, kernel, padding='same') + tensor([[[[0., 0., 0., 0., 0.], + [0., 5., 5., 5., 0.], + [0., 5., 5., 5., 0.], + [0., 5., 5., 5., 0.], + [0., 0., 0., 0., 0.]]]]) + """ + out_x = filter2d(input, kernel_x.unsqueeze(0), border_type, normalized, padding) + out = filter2d(out_x, kernel_y.unsqueeze(-1), border_type, normalized, padding) + return out + + +def filter3d( + input: torch.Tensor, kernel: torch.Tensor, border_type: str = 'replicate', normalized: bool = False +) -> torch.Tensor: + r"""Convolve a tensor with a 3d kernel. + + The function applies a given kernel to a tensor. The kernel is applied + independently at each depth channel of the tensor. Before applying the + kernel, the function applies padding according to the specified mode so + that the output remains in the same shape. + + Args: + input: the input tensor with shape of + :math:`(B, C, D, H, W)`. + kernel: the kernel to be convolved with the input + tensor. The kernel shape must be :math:`(1, kD, kH, kW)` or :math:`(B, kD, kH, kW)`. + border_type: the padding mode to be applied before convolving. + The expected modes are: ``'constant'``, + ``'replicate'`` or ``'circular'``. + normalized: If True, kernel will be L1 normalized. + + Return: + the convolved tensor of same size and numbers of channels + as the input with shape :math:`(B, C, D, H, W)`. + + Example: + >>> input = torch.tensor([[[ + ... [[0., 0., 0., 0., 0.], + ... [0., 0., 0., 0., 0.], + ... [0., 0., 0., 0., 0.], + ... [0., 0., 0., 0., 0.], + ... [0., 0., 0., 0., 0.]], + ... [[0., 0., 0., 0., 0.], + ... [0., 0., 0., 0., 0.], + ... [0., 0., 5., 0., 0.], + ... [0., 0., 0., 0., 0.], + ... [0., 0., 0., 0., 0.]], + ... [[0., 0., 0., 0., 0.], + ... [0., 0., 0., 0., 0.], + ... [0., 0., 0., 0., 0.], + ... [0., 0., 0., 0., 0.], + ... [0., 0., 0., 0., 0.]] + ... ]]]) + >>> kernel = torch.ones(1, 3, 3, 3) + >>> filter3d(input, kernel) + tensor([[[[[0., 0., 0., 0., 0.], + [0., 5., 5., 5., 0.], + [0., 5., 5., 5., 0.], + [0., 5., 5., 5., 0.], + [0., 0., 0., 0., 0.]], + + [[0., 0., 0., 0., 0.], + [0., 5., 5., 5., 0.], + [0., 5., 5., 5., 0.], + [0., 5., 5., 5., 0.], + [0., 0., 0., 0., 0.]], + + [[0., 0., 0., 0., 0.], + [0., 5., 5., 5., 0.], + [0., 5., 5., 5., 0.], + [0., 5., 5., 5., 0.], + [0., 0., 0., 0., 0.]]]]]) + """ + if not isinstance(input, torch.Tensor): + raise TypeError(f"Input border_type is not torch.Tensor. Got {type(input)}") + + if not isinstance(kernel, torch.Tensor): + raise TypeError(f"Input border_type is not torch.Tensor. Got {type(kernel)}") + + if not isinstance(border_type, str): + raise TypeError(f"Input border_type is not string. Got {type(kernel)}") + + if not len(input.shape) == 5: + raise ValueError(f"Invalid input shape, we expect BxCxDxHxW. 
Got: {input.shape}") + + if not len(kernel.shape) == 4 and kernel.shape[0] != 1: + raise ValueError(f"Invalid kernel shape, we expect 1xDxHxW. Got: {kernel.shape}") + + # prepare kernel + b, c, d, h, w = input.shape + tmp_kernel: torch.Tensor = kernel.unsqueeze(1).to(input) + + if normalized: + bk, dk, hk, wk = kernel.shape + tmp_kernel = normalize_kernel2d(tmp_kernel.view(bk, dk, hk * wk)).view_as(tmp_kernel) + + tmp_kernel = tmp_kernel.expand(-1, c, -1, -1, -1) + + # pad the input tensor + depth, height, width = tmp_kernel.shape[-3:] + padding_shape: List[int] = _compute_padding([depth, height, width]) + input_pad: torch.Tensor = F.pad(input, padding_shape, mode=border_type) + + # kernel and input tensor reshape to align element-wise or batch-wise params + tmp_kernel = tmp_kernel.reshape(-1, 1, depth, height, width) + input_pad = input_pad.view(-1, tmp_kernel.size(0), input_pad.size(-3), input_pad.size(-2), input_pad.size(-1)) + + # convolve the tensor with the kernel. + output = F.conv3d(input_pad, tmp_kernel, groups=tmp_kernel.size(0), padding=0, stride=1) + + return output.view(b, c, d, h, w) \ No newline at end of file diff --git a/propainter/model/canny/gaussian.py b/propainter/model/canny/gaussian.py new file mode 100644 index 0000000000000000000000000000000000000000..182f05c5d7d297d97b3dd008287e053493350bb6 --- /dev/null +++ b/propainter/model/canny/gaussian.py @@ -0,0 +1,116 @@ +from typing import Tuple + +import torch +import torch.nn as nn + +from .filter import filter2d, filter2d_separable +from .kernels import get_gaussian_kernel1d, get_gaussian_kernel2d + + +def gaussian_blur2d( + input: torch.Tensor, + kernel_size: Tuple[int, int], + sigma: Tuple[float, float], + border_type: str = 'reflect', + separable: bool = True, +) -> torch.Tensor: + r"""Create an operator that blurs a tensor using a Gaussian filter. + + .. image:: _static/img/gaussian_blur2d.png + + The operator smooths the given tensor with a gaussian kernel by convolving + it to each channel. It supports batched operation. + + Arguments: + input: the input tensor with shape :math:`(B,C,H,W)`. + kernel_size: the size of the kernel. + sigma: the standard deviation of the kernel. + border_type: the padding mode to be applied before convolving. + The expected modes are: ``'constant'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'reflect'``. + separable: run as composition of two 1d-convolutions. + + Returns: + the blurred tensor with shape :math:`(B, C, H, W)`. + + .. note:: + See a working example `here `__. + + Examples: + >>> input = torch.rand(2, 4, 5, 5) + >>> output = gaussian_blur2d(input, (3, 3), (1.5, 1.5)) + >>> output.shape + torch.Size([2, 4, 5, 5]) + """ + if separable: + kernel_x: torch.Tensor = get_gaussian_kernel1d(kernel_size[1], sigma[1]) + kernel_y: torch.Tensor = get_gaussian_kernel1d(kernel_size[0], sigma[0]) + out = filter2d_separable(input, kernel_x[None], kernel_y[None], border_type) + else: + kernel: torch.Tensor = get_gaussian_kernel2d(kernel_size, sigma) + out = filter2d(input, kernel[None], border_type) + return out + + +class GaussianBlur2d(nn.Module): + r"""Create an operator that blurs a tensor using a Gaussian filter. + + The operator smooths the given tensor with a gaussian kernel by convolving + it to each channel. It supports batched operation. + + Arguments: + kernel_size: the size of the kernel. + sigma: the standard deviation of the kernel. + border_type: the padding mode to be applied before convolving. 
+ The expected modes are: ``'constant'``, ``'reflect'``, + ``'replicate'`` or ``'circular'``. Default: ``'reflect'``. + separable: run as composition of two 1d-convolutions. + + Returns: + the blurred tensor. + + Shape: + - Input: :math:`(B, C, H, W)` + - Output: :math:`(B, C, H, W)` + + Examples:: + + >>> input = torch.rand(2, 4, 5, 5) + >>> gauss = GaussianBlur2d((3, 3), (1.5, 1.5)) + >>> output = gauss(input) # 2x4x5x5 + >>> output.shape + torch.Size([2, 4, 5, 5]) + """ + + def __init__( + self, + kernel_size: Tuple[int, int], + sigma: Tuple[float, float], + border_type: str = 'reflect', + separable: bool = True, + ) -> None: + super().__init__() + self.kernel_size: Tuple[int, int] = kernel_size + self.sigma: Tuple[float, float] = sigma + self.border_type = border_type + self.separable = separable + + def __repr__(self) -> str: + return ( + self.__class__.__name__ + + '(kernel_size=' + + str(self.kernel_size) + + ', ' + + 'sigma=' + + str(self.sigma) + + ', ' + + 'border_type=' + + self.border_type + + 'separable=' + + str(self.separable) + + ')' + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return gaussian_blur2d(input, self.kernel_size, self.sigma, self.border_type, self.separable) \ No newline at end of file diff --git a/propainter/model/canny/kernels.py b/propainter/model/canny/kernels.py new file mode 100644 index 0000000000000000000000000000000000000000..ae1ee251b8363ba76c7b63c6925a1776c50b7f32 --- /dev/null +++ b/propainter/model/canny/kernels.py @@ -0,0 +1,690 @@ +import math +from math import sqrt +from typing import List, Optional, Tuple + +import torch + + +def normalize_kernel2d(input: torch.Tensor) -> torch.Tensor: + r"""Normalize both derivative and smoothing kernel.""" + if len(input.size()) < 2: + raise TypeError(f"input should be at least 2D tensor. Got {input.size()}") + norm: torch.Tensor = input.abs().sum(dim=-1).sum(dim=-1) + return input / (norm.unsqueeze(-1).unsqueeze(-1)) + + +def gaussian(window_size: int, sigma: float) -> torch.Tensor: + device, dtype = None, None + if isinstance(sigma, torch.Tensor): + device, dtype = sigma.device, sigma.dtype + x = torch.arange(window_size, device=device, dtype=dtype) - window_size // 2 + if window_size % 2 == 0: + x = x + 0.5 + + gauss = torch.exp((-x.pow(2.0) / (2 * sigma**2)).float()) + return gauss / gauss.sum() + + +def gaussian_discrete_erf(window_size: int, sigma) -> torch.Tensor: + r"""Discrete Gaussian by interpolating the error function. 
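+    Each coefficient integrates a unit-area Gaussian over a unit-wide bin, i.e. it is proportional to ``0.5 * (erf((k + 0.5) / (sqrt(2) * sigma)) - erf((k - 0.5) / (sqrt(2) * sigma)))``; the constant ``0.70710678`` used below is :math:`1/\sqrt{2}`.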
+ + Adapted from: + https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py + """ + device = sigma.device if isinstance(sigma, torch.Tensor) else None + sigma = torch.as_tensor(sigma, dtype=torch.float, device=device) + x = torch.arange(window_size).float() - window_size // 2 + t = 0.70710678 / torch.abs(sigma) + gauss = 0.5 * ((t * (x + 0.5)).erf() - (t * (x - 0.5)).erf()) + gauss = gauss.clamp(min=0) + return gauss / gauss.sum() + + +def _modified_bessel_0(x: torch.Tensor) -> torch.Tensor: + r"""Adapted from: + + https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py + """ + if torch.abs(x) < 3.75: + y = (x / 3.75) * (x / 3.75) + return 1.0 + y * ( + 3.5156229 + y * (3.0899424 + y * (1.2067492 + y * (0.2659732 + y * (0.360768e-1 + y * 0.45813e-2)))) + ) + ax = torch.abs(x) + y = 3.75 / ax + ans = 0.916281e-2 + y * (-0.2057706e-1 + y * (0.2635537e-1 + y * (-0.1647633e-1 + y * 0.392377e-2))) + coef = 0.39894228 + y * (0.1328592e-1 + y * (0.225319e-2 + y * (-0.157565e-2 + y * ans))) + return (torch.exp(ax) / torch.sqrt(ax)) * coef + + +def _modified_bessel_1(x: torch.Tensor) -> torch.Tensor: + r"""adapted from: + + https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py + """ + if torch.abs(x) < 3.75: + y = (x / 3.75) * (x / 3.75) + ans = 0.51498869 + y * (0.15084934 + y * (0.2658733e-1 + y * (0.301532e-2 + y * 0.32411e-3))) + return torch.abs(x) * (0.5 + y * (0.87890594 + y * ans)) + ax = torch.abs(x) + y = 3.75 / ax + ans = 0.2282967e-1 + y * (-0.2895312e-1 + y * (0.1787654e-1 - y * 0.420059e-2)) + ans = 0.39894228 + y * (-0.3988024e-1 + y * (-0.362018e-2 + y * (0.163801e-2 + y * (-0.1031555e-1 + y * ans)))) + ans = ans * torch.exp(ax) / torch.sqrt(ax) + return -ans if x < 0.0 else ans + + +def _modified_bessel_i(n: int, x: torch.Tensor) -> torch.Tensor: + r"""adapted from: + + https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py + """ + if n < 2: + raise ValueError("n must be greater than 1.") + if x == 0.0: + return x + device = x.device + tox = 2.0 / torch.abs(x) + ans = torch.tensor(0.0, device=device) + bip = torch.tensor(0.0, device=device) + bi = torch.tensor(1.0, device=device) + m = int(2 * (n + int(sqrt(40.0 * n)))) + for j in range(m, 0, -1): + bim = bip + float(j) * tox * bi + bip = bi + bi = bim + if abs(bi) > 1.0e10: + ans = ans * 1.0e-10 + bi = bi * 1.0e-10 + bip = bip * 1.0e-10 + if j == n: + ans = bip + ans = ans * _modified_bessel_0(x) / bi + return -ans if x < 0.0 and (n % 2) == 1 else ans + + +def gaussian_discrete(window_size, sigma) -> torch.Tensor: + r"""Discrete Gaussian kernel based on the modified Bessel functions. 
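+    The coefficients follow the discrete Gaussian :math:`T(k, \sigma^2) = e^{-\sigma^2} I_k(\sigma^2)`, with :math:`I_k` the modified Bessel functions computed by the helpers above; the common exponential factor cancels in the final normalization.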
+ + Adapted from: + https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py + """ + device = sigma.device if isinstance(sigma, torch.Tensor) else None + sigma = torch.as_tensor(sigma, dtype=torch.float, device=device) + sigma2 = sigma * sigma + tail = int(window_size // 2) + out_pos: List[Optional[torch.Tensor]] = [None] * (tail + 1) + out_pos[0] = _modified_bessel_0(sigma2) + out_pos[1] = _modified_bessel_1(sigma2) + for k in range(2, len(out_pos)): + out_pos[k] = _modified_bessel_i(k, sigma2) + out = out_pos[:0:-1] + out.extend(out_pos) + out = torch.stack(out) * torch.exp(sigma2) # type: ignore + return out / out.sum() # type: ignore + + +def laplacian_1d(window_size) -> torch.Tensor: + r"""One could also use the Laplacian of Gaussian formula to design the filter.""" + + filter_1d = torch.ones(window_size) + filter_1d[window_size // 2] = 1 - window_size + laplacian_1d: torch.Tensor = filter_1d + return laplacian_1d + + +def get_box_kernel2d(kernel_size: Tuple[int, int]) -> torch.Tensor: + r"""Utility function that returns a box filter.""" + kx: float = float(kernel_size[0]) + ky: float = float(kernel_size[1]) + scale: torch.Tensor = torch.tensor(1.0) / torch.tensor([kx * ky]) + tmp_kernel: torch.Tensor = torch.ones(1, kernel_size[0], kernel_size[1]) + return scale.to(tmp_kernel.dtype) * tmp_kernel + + +def get_binary_kernel2d(window_size: Tuple[int, int]) -> torch.Tensor: + r"""Create a binary kernel to extract the patches. + + If the window size is HxW will create a (H*W)xHxW kernel. + """ + window_range: int = window_size[0] * window_size[1] + kernel: torch.Tensor = torch.zeros(window_range, window_range) + for i in range(window_range): + kernel[i, i] += 1.0 + return kernel.view(window_range, 1, window_size[0], window_size[1]) + + +def get_sobel_kernel_3x3() -> torch.Tensor: + """Utility function that returns a sobel kernel of 3x3.""" + return torch.tensor([[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]]) + + +def get_sobel_kernel_5x5_2nd_order() -> torch.Tensor: + """Utility function that returns a 2nd order sobel kernel of 5x5.""" + return torch.tensor( + [ + [-1.0, 0.0, 2.0, 0.0, -1.0], + [-4.0, 0.0, 8.0, 0.0, -4.0], + [-6.0, 0.0, 12.0, 0.0, -6.0], + [-4.0, 0.0, 8.0, 0.0, -4.0], + [-1.0, 0.0, 2.0, 0.0, -1.0], + ] + ) + + +def _get_sobel_kernel_5x5_2nd_order_xy() -> torch.Tensor: + """Utility function that returns a 2nd order sobel kernel of 5x5.""" + return torch.tensor( + [ + [-1.0, -2.0, 0.0, 2.0, 1.0], + [-2.0, -4.0, 0.0, 4.0, 2.0], + [0.0, 0.0, 0.0, 0.0, 0.0], + [2.0, 4.0, 0.0, -4.0, -2.0], + [1.0, 2.0, 0.0, -2.0, -1.0], + ] + ) + + +def get_diff_kernel_3x3() -> torch.Tensor: + """Utility function that returns a first order derivative kernel of 3x3.""" + return torch.tensor([[-0.0, 0.0, 0.0], [-1.0, 0.0, 1.0], [-0.0, 0.0, 0.0]]) + + +def get_diff_kernel3d(device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor: + """Utility function that returns a first order derivative kernel of 3x3x3.""" + kernel: torch.Tensor = torch.tensor( + [ + [ + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], [-0.5, 0.0, 0.5], [0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + ], + [ + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + [[0.0, -0.5, 0.0], [0.0, 0.0, 0.0], [0.0, 0.5, 0.0]], + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + ], + [ + [[0.0, 0.0, 0.0], [0.0, -0.5, 0.0], [0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], [0.0, 0.5, 0.0], [0.0, 0.0, 
0.0]], + ], + ], + device=device, + dtype=dtype, + ) + return kernel.unsqueeze(1) + + +def get_diff_kernel3d_2nd_order(device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor: + """Utility function that returns a first order derivative kernel of 3x3x3.""" + kernel: torch.Tensor = torch.tensor( + [ + [ + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], [1.0, -2.0, 1.0], [0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + ], + [ + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + [[0.0, 1.0, 0.0], [0.0, -2.0, 0.0], [0.0, 1.0, 0.0]], + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + ], + [ + [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], [0.0, -2.0, 0.0], [0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], + ], + [ + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + [[1.0, 0.0, -1.0], [0.0, 0.0, 0.0], [-1.0, 0.0, 1.0]], + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + ], + [ + [[0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [0.0, -1.0, 0.0]], + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + [[0.0, -1.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]], + ], + [ + [[0.0, 0.0, 0.0], [1.0, 0.0, -1.0], [0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], [-1.0, 0.0, 1.0], [0.0, 0.0, 0.0]], + ], + ], + device=device, + dtype=dtype, + ) + return kernel.unsqueeze(1) + + +def get_sobel_kernel2d() -> torch.Tensor: + kernel_x: torch.Tensor = get_sobel_kernel_3x3() + kernel_y: torch.Tensor = kernel_x.transpose(0, 1) + return torch.stack([kernel_x, kernel_y]) + + +def get_diff_kernel2d() -> torch.Tensor: + kernel_x: torch.Tensor = get_diff_kernel_3x3() + kernel_y: torch.Tensor = kernel_x.transpose(0, 1) + return torch.stack([kernel_x, kernel_y]) + + +def get_sobel_kernel2d_2nd_order() -> torch.Tensor: + gxx: torch.Tensor = get_sobel_kernel_5x5_2nd_order() + gyy: torch.Tensor = gxx.transpose(0, 1) + gxy: torch.Tensor = _get_sobel_kernel_5x5_2nd_order_xy() + return torch.stack([gxx, gxy, gyy]) + + +def get_diff_kernel2d_2nd_order() -> torch.Tensor: + gxx: torch.Tensor = torch.tensor([[0.0, 0.0, 0.0], [1.0, -2.0, 1.0], [0.0, 0.0, 0.0]]) + gyy: torch.Tensor = gxx.transpose(0, 1) + gxy: torch.Tensor = torch.tensor([[-1.0, 0.0, 1.0], [0.0, 0.0, 0.0], [1.0, 0.0, -1.0]]) + return torch.stack([gxx, gxy, gyy]) + + +def get_spatial_gradient_kernel2d(mode: str, order: int) -> torch.Tensor: + r"""Function that returns kernel for 1st or 2nd order image gradients, using one of the following operators: + + sobel, diff. + """ + if mode not in ['sobel', 'diff']: + raise TypeError( + "mode should be either sobel\ + or diff. 
Got {}".format( + mode + ) + ) + if order not in [1, 2]: + raise TypeError( + "order should be either 1 or 2\ + Got {}".format( + order + ) + ) + if mode == 'sobel' and order == 1: + kernel: torch.Tensor = get_sobel_kernel2d() + elif mode == 'sobel' and order == 2: + kernel = get_sobel_kernel2d_2nd_order() + elif mode == 'diff' and order == 1: + kernel = get_diff_kernel2d() + elif mode == 'diff' and order == 2: + kernel = get_diff_kernel2d_2nd_order() + else: + raise NotImplementedError("") + return kernel + + +def get_spatial_gradient_kernel3d(mode: str, order: int, device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor: + r"""Function that returns kernel for 1st or 2nd order scale pyramid gradients, using one of the following + operators: sobel, diff.""" + if mode not in ['sobel', 'diff']: + raise TypeError( + "mode should be either sobel\ + or diff. Got {}".format( + mode + ) + ) + if order not in [1, 2]: + raise TypeError( + "order should be either 1 or 2\ + Got {}".format( + order + ) + ) + if mode == 'sobel': + raise NotImplementedError("Sobel kernel for 3d gradient is not implemented yet") + if mode == 'diff' and order == 1: + kernel = get_diff_kernel3d(device, dtype) + elif mode == 'diff' and order == 2: + kernel = get_diff_kernel3d_2nd_order(device, dtype) + else: + raise NotImplementedError("") + return kernel + + +def get_gaussian_kernel1d(kernel_size: int, sigma: float, force_even: bool = False) -> torch.Tensor: + r"""Function that returns Gaussian filter coefficients. + + Args: + kernel_size: filter size. It should be odd and positive. + sigma: gaussian standard deviation. + force_even: overrides requirement for odd kernel size. + + Returns: + 1D tensor with gaussian filter coefficients. + + Shape: + - Output: :math:`(\text{kernel_size})` + + Examples: + + >>> get_gaussian_kernel1d(3, 2.5) + tensor([0.3243, 0.3513, 0.3243]) + + >>> get_gaussian_kernel1d(5, 1.5) + tensor([0.1201, 0.2339, 0.2921, 0.2339, 0.1201]) + """ + if not isinstance(kernel_size, int) or ((kernel_size % 2 == 0) and not force_even) or (kernel_size <= 0): + raise TypeError("kernel_size must be an odd positive integer. " "Got {}".format(kernel_size)) + window_1d: torch.Tensor = gaussian(kernel_size, sigma) + return window_1d + + +def get_gaussian_discrete_kernel1d(kernel_size: int, sigma: float, force_even: bool = False) -> torch.Tensor: + r"""Function that returns Gaussian filter coefficients based on the modified Bessel functions. Adapted from: + https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py. + + Args: + kernel_size: filter size. It should be odd and positive. + sigma: gaussian standard deviation. + force_even: overrides requirement for odd kernel size. + + Returns: + 1D tensor with gaussian filter coefficients. + + Shape: + - Output: :math:`(\text{kernel_size})` + + Examples: + + >>> get_gaussian_discrete_kernel1d(3, 2.5) + tensor([0.3235, 0.3531, 0.3235]) + + >>> get_gaussian_discrete_kernel1d(5, 1.5) + tensor([0.1096, 0.2323, 0.3161, 0.2323, 0.1096]) + """ + if not isinstance(kernel_size, int) or ((kernel_size % 2 == 0) and not force_even) or (kernel_size <= 0): + raise TypeError("kernel_size must be an odd positive integer. 
" "Got {}".format(kernel_size)) + window_1d = gaussian_discrete(kernel_size, sigma) + return window_1d + + +def get_gaussian_erf_kernel1d(kernel_size: int, sigma: float, force_even: bool = False) -> torch.Tensor: + r"""Function that returns Gaussian filter coefficients by interpolating the error function, adapted from: + https://github.com/Project-MONAI/MONAI/blob/master/monai/networks/layers/convutils.py. + + Args: + kernel_size: filter size. It should be odd and positive. + sigma: gaussian standard deviation. + force_even: overrides requirement for odd kernel size. + + Returns: + 1D tensor with gaussian filter coefficients. + + Shape: + - Output: :math:`(\text{kernel_size})` + + Examples: + + >>> get_gaussian_erf_kernel1d(3, 2.5) + tensor([0.3245, 0.3511, 0.3245]) + + >>> get_gaussian_erf_kernel1d(5, 1.5) + tensor([0.1226, 0.2331, 0.2887, 0.2331, 0.1226]) + """ + if not isinstance(kernel_size, int) or ((kernel_size % 2 == 0) and not force_even) or (kernel_size <= 0): + raise TypeError("kernel_size must be an odd positive integer. " "Got {}".format(kernel_size)) + window_1d = gaussian_discrete_erf(kernel_size, sigma) + return window_1d + + +def get_gaussian_kernel2d( + kernel_size: Tuple[int, int], sigma: Tuple[float, float], force_even: bool = False +) -> torch.Tensor: + r"""Function that returns Gaussian filter matrix coefficients. + + Args: + kernel_size: filter sizes in the x and y direction. + Sizes should be odd and positive. + sigma: gaussian standard deviation in the x and y + direction. + force_even: overrides requirement for odd kernel size. + + Returns: + 2D tensor with gaussian filter matrix coefficients. + + Shape: + - Output: :math:`(\text{kernel_size}_x, \text{kernel_size}_y)` + + Examples: + >>> get_gaussian_kernel2d((3, 3), (1.5, 1.5)) + tensor([[0.0947, 0.1183, 0.0947], + [0.1183, 0.1478, 0.1183], + [0.0947, 0.1183, 0.0947]]) + >>> get_gaussian_kernel2d((3, 5), (1.5, 1.5)) + tensor([[0.0370, 0.0720, 0.0899, 0.0720, 0.0370], + [0.0462, 0.0899, 0.1123, 0.0899, 0.0462], + [0.0370, 0.0720, 0.0899, 0.0720, 0.0370]]) + """ + if not isinstance(kernel_size, tuple) or len(kernel_size) != 2: + raise TypeError(f"kernel_size must be a tuple of length two. Got {kernel_size}") + if not isinstance(sigma, tuple) or len(sigma) != 2: + raise TypeError(f"sigma must be a tuple of length two. Got {sigma}") + ksize_x, ksize_y = kernel_size + sigma_x, sigma_y = sigma + kernel_x: torch.Tensor = get_gaussian_kernel1d(ksize_x, sigma_x, force_even) + kernel_y: torch.Tensor = get_gaussian_kernel1d(ksize_y, sigma_y, force_even) + kernel_2d: torch.Tensor = torch.matmul(kernel_x.unsqueeze(-1), kernel_y.unsqueeze(-1).t()) + return kernel_2d + + +def get_laplacian_kernel1d(kernel_size: int) -> torch.Tensor: + r"""Function that returns the coefficients of a 1D Laplacian filter. + + Args: + kernel_size: filter size. It should be odd and positive. + + Returns: + 1D tensor with laplacian filter coefficients. + + Shape: + - Output: math:`(\text{kernel_size})` + + Examples: + >>> get_laplacian_kernel1d(3) + tensor([ 1., -2., 1.]) + >>> get_laplacian_kernel1d(5) + tensor([ 1., 1., -4., 1., 1.]) + """ + if not isinstance(kernel_size, int) or kernel_size % 2 == 0 or kernel_size <= 0: + raise TypeError(f"ksize must be an odd positive integer. Got {kernel_size}") + window_1d: torch.Tensor = laplacian_1d(kernel_size) + return window_1d + + +def get_laplacian_kernel2d(kernel_size: int) -> torch.Tensor: + r"""Function that returns Gaussian filter matrix coefficients. 
+ + Args: + kernel_size: filter size should be odd. + + Returns: + 2D tensor with laplacian filter matrix coefficients. + + Shape: + - Output: :math:`(\text{kernel_size}_x, \text{kernel_size}_y)` + + Examples: + >>> get_laplacian_kernel2d(3) + tensor([[ 1., 1., 1.], + [ 1., -8., 1.], + [ 1., 1., 1.]]) + >>> get_laplacian_kernel2d(5) + tensor([[ 1., 1., 1., 1., 1.], + [ 1., 1., 1., 1., 1.], + [ 1., 1., -24., 1., 1.], + [ 1., 1., 1., 1., 1.], + [ 1., 1., 1., 1., 1.]]) + """ + if not isinstance(kernel_size, int) or kernel_size % 2 == 0 or kernel_size <= 0: + raise TypeError(f"ksize must be an odd positive integer. Got {kernel_size}") + + kernel = torch.ones((kernel_size, kernel_size)) + mid = kernel_size // 2 + kernel[mid, mid] = 1 - kernel_size**2 + kernel_2d: torch.Tensor = kernel + return kernel_2d + + +def get_pascal_kernel_2d(kernel_size: int, norm: bool = True) -> torch.Tensor: + """Generate pascal filter kernel by kernel size. + + Args: + kernel_size: height and width of the kernel. + norm: if to normalize the kernel or not. Default: True. + + Returns: + kernel shaped as :math:`(kernel_size, kernel_size)` + + Examples: + >>> get_pascal_kernel_2d(1) + tensor([[1.]]) + >>> get_pascal_kernel_2d(4) + tensor([[0.0156, 0.0469, 0.0469, 0.0156], + [0.0469, 0.1406, 0.1406, 0.0469], + [0.0469, 0.1406, 0.1406, 0.0469], + [0.0156, 0.0469, 0.0469, 0.0156]]) + >>> get_pascal_kernel_2d(4, norm=False) + tensor([[1., 3., 3., 1.], + [3., 9., 9., 3.], + [3., 9., 9., 3.], + [1., 3., 3., 1.]]) + """ + a = get_pascal_kernel_1d(kernel_size) + + filt = a[:, None] * a[None, :] + if norm: + filt = filt / torch.sum(filt) + return filt + + +def get_pascal_kernel_1d(kernel_size: int, norm: bool = False) -> torch.Tensor: + """Generate Yang Hui triangle (Pascal's triangle) by a given number. + + Args: + kernel_size: height and width of the kernel. + norm: if to normalize the kernel or not. Default: False. 
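+    The entries are the binomial coefficients of row ``kernel_size - 1`` of Pascal's (Yang Hui's) triangle, which approximate a sampled Gaussian as the kernel size grows.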
+ + Returns: + kernel shaped as :math:`(kernel_size,)` + + Examples: + >>> get_pascal_kernel_1d(1) + tensor([1.]) + >>> get_pascal_kernel_1d(2) + tensor([1., 1.]) + >>> get_pascal_kernel_1d(3) + tensor([1., 2., 1.]) + >>> get_pascal_kernel_1d(4) + tensor([1., 3., 3., 1.]) + >>> get_pascal_kernel_1d(5) + tensor([1., 4., 6., 4., 1.]) + >>> get_pascal_kernel_1d(6) + tensor([ 1., 5., 10., 10., 5., 1.]) + """ + pre: List[float] = [] + cur: List[float] = [] + for i in range(kernel_size): + cur = [1.0] * (i + 1) + + for j in range(1, i // 2 + 1): + value = pre[j - 1] + pre[j] + cur[j] = value + if i != 2 * j: + cur[-j - 1] = value + pre = cur + + out = torch.as_tensor(cur) + if norm: + out = out / torch.sum(out) + return out + + +def get_canny_nms_kernel(device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor: + """Utility function that returns 3x3 kernels for the Canny Non-maximal suppression.""" + kernel: torch.Tensor = torch.tensor( + [ + [[0.0, 0.0, 0.0], [0.0, 1.0, -1.0], [0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]], + [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, -1.0, 0.0]], + [[0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [-1.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], [-1.0, 1.0, 0.0], [0.0, 0.0, 0.0]], + [[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], + [[0.0, -1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], + [[0.0, 0.0, -1.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]], + ], + device=device, + dtype=dtype, + ) + return kernel.unsqueeze(1) + + +def get_hysteresis_kernel(device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor: + """Utility function that returns the 3x3 kernels for the Canny hysteresis.""" + kernel: torch.Tensor = torch.tensor( + [ + [[0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 1.0]], + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]], + [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [1.0, 0.0, 0.0]], + [[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + [[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + [[0.0, 1.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + [[0.0, 0.0, 1.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]], + ], + device=device, + dtype=dtype, + ) + return kernel.unsqueeze(1) + + +def get_hanning_kernel1d(kernel_size: int, device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor: + r"""Returns Hanning (also known as Hann) kernel, used in signal processing and KCF tracker. + + .. math:: w(n) = 0.5 - 0.5cos\\left(\\frac{2\\pi{n}}{M-1}\\right) + \\qquad 0 \\leq n \\leq M-1 + + See further in numpy docs https://numpy.org/doc/stable/reference/generated/numpy.hanning.html + + Args: + kernel_size: The size the of the kernel. It should be positive. + + Returns: + 1D tensor with Hanning filter coefficients. + .. math:: w(n) = 0.5 - 0.5cos\\left(\\frac{2\\pi{n}}{M-1}\\right) + + Shape: + - Output: math:`(\text{kernel_size})` + + Examples: + >>> get_hanning_kernel1d(4) + tensor([0.0000, 0.7500, 0.7500, 0.0000]) + """ + if not isinstance(kernel_size, int) or kernel_size <= 2: + raise TypeError(f"ksize must be an positive integer > 2. Got {kernel_size}") + + x: torch.Tensor = torch.arange(kernel_size, device=device, dtype=dtype) + x = 0.5 - 0.5 * torch.cos(2.0 * math.pi * x / float(kernel_size - 1)) + return x + + +def get_hanning_kernel2d(kernel_size: Tuple[int, int], device=torch.device('cpu'), dtype=torch.float) -> torch.Tensor: + r"""Returns 2d Hanning kernel, used in signal processing and KCF tracker. + + Args: + kernel_size: The size of the kernel for the filter. It should be positive. 
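+    The 2D window is formed as the outer product of two 1D Hann windows, one along each axis.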
+ + Returns: + 2D tensor with Hanning filter coefficients. + .. math:: w(n) = 0.5 - 0.5cos\\left(\\frac{2\\pi{n}}{M-1}\\right) + + Shape: + - Output: math:`(\text{kernel_size[0], kernel_size[1]})` + """ + if kernel_size[0] <= 2 or kernel_size[1] <= 2: + raise TypeError(f"ksize must be an tuple of positive integers > 2. Got {kernel_size}") + ky: torch.Tensor = get_hanning_kernel1d(kernel_size[0], device, dtype)[None].T + kx: torch.Tensor = get_hanning_kernel1d(kernel_size[1], device, dtype)[None] + kernel2d = ky @ kx + return kernel2d \ No newline at end of file diff --git a/propainter/model/canny/sobel.py b/propainter/model/canny/sobel.py new file mode 100644 index 0000000000000000000000000000000000000000..d780c5c4a22bb6403122a292b6d30fa022f262e8 --- /dev/null +++ b/propainter/model/canny/sobel.py @@ -0,0 +1,263 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from .kernels import get_spatial_gradient_kernel2d, get_spatial_gradient_kernel3d, normalize_kernel2d + + +def spatial_gradient(input: torch.Tensor, mode: str = 'sobel', order: int = 1, normalized: bool = True) -> torch.Tensor: + r"""Compute the first order image derivative in both x and y using a Sobel operator. + + .. image:: _static/img/spatial_gradient.png + + Args: + input: input image tensor with shape :math:`(B, C, H, W)`. + mode: derivatives modality, can be: `sobel` or `diff`. + order: the order of the derivatives. + normalized: whether the output is normalized. + + Return: + the derivatives of the input feature map. with shape :math:`(B, C, 2, H, W)`. + + .. note:: + See a working example `here `__. + + Examples: + >>> input = torch.rand(1, 3, 4, 4) + >>> output = spatial_gradient(input) # 1x3x2x4x4 + >>> output.shape + torch.Size([1, 3, 2, 4, 4]) + """ + if not isinstance(input, torch.Tensor): + raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}") + + if not len(input.shape) == 4: + raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}") + # allocate kernel + kernel: torch.Tensor = get_spatial_gradient_kernel2d(mode, order) + if normalized: + kernel = normalize_kernel2d(kernel) + + # prepare kernel + b, c, h, w = input.shape + tmp_kernel: torch.Tensor = kernel.to(input).detach() + tmp_kernel = tmp_kernel.unsqueeze(1).unsqueeze(1) + + # convolve input tensor with sobel kernel + kernel_flip: torch.Tensor = tmp_kernel.flip(-3) + + # Pad with "replicate for spatial dims, but with zeros for channel + spatial_pad = [kernel.size(1) // 2, kernel.size(1) // 2, kernel.size(2) // 2, kernel.size(2) // 2] + out_channels: int = 3 if order == 2 else 2 + padded_inp: torch.Tensor = F.pad(input.reshape(b * c, 1, h, w), spatial_pad, 'replicate')[:, :, None] + + return F.conv3d(padded_inp, kernel_flip, padding=0).view(b, c, out_channels, h, w) + + +def spatial_gradient3d(input: torch.Tensor, mode: str = 'diff', order: int = 1) -> torch.Tensor: + r"""Compute the first and second order volume derivative in x, y and d using a diff operator. + + Args: + input: input features tensor with shape :math:`(B, C, D, H, W)`. + mode: derivatives modality, can be: `sobel` or `diff`. + order: the order of the derivatives. + + Return: + the spatial gradients of the input feature map with shape math:`(B, C, 3, D, H, W)` + or :math:`(B, C, 6, D, H, W)`. 
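+    .. note:: For ``mode='diff'`` and ``order=1`` the gradients are computed directly as central differences (half the difference of the two neighbours along each spatial axis), which avoids a slow ``conv3d``; other mode/order combinations fall back to convolution with the corresponding kernel.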
+ + Examples: + >>> input = torch.rand(1, 4, 2, 4, 4) + >>> output = spatial_gradient3d(input) + >>> output.shape + torch.Size([1, 4, 3, 2, 4, 4]) + """ + if not isinstance(input, torch.Tensor): + raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}") + + if not len(input.shape) == 5: + raise ValueError(f"Invalid input shape, we expect BxCxDxHxW. Got: {input.shape}") + b, c, d, h, w = input.shape + dev = input.device + dtype = input.dtype + if (mode == 'diff') and (order == 1): + # we go for the special case implementation due to conv3d bad speed + x: torch.Tensor = F.pad(input, 6 * [1], 'replicate') + center = slice(1, -1) + left = slice(0, -2) + right = slice(2, None) + out = torch.empty(b, c, 3, d, h, w, device=dev, dtype=dtype) + out[..., 0, :, :, :] = x[..., center, center, right] - x[..., center, center, left] + out[..., 1, :, :, :] = x[..., center, right, center] - x[..., center, left, center] + out[..., 2, :, :, :] = x[..., right, center, center] - x[..., left, center, center] + out = 0.5 * out + else: + # prepare kernel + # allocate kernel + kernel: torch.Tensor = get_spatial_gradient_kernel3d(mode, order) + + tmp_kernel: torch.Tensor = kernel.to(input).detach() + tmp_kernel = tmp_kernel.repeat(c, 1, 1, 1, 1) + + # convolve input tensor with grad kernel + kernel_flip: torch.Tensor = tmp_kernel.flip(-3) + + # Pad with "replicate for spatial dims, but with zeros for channel + spatial_pad = [ + kernel.size(2) // 2, + kernel.size(2) // 2, + kernel.size(3) // 2, + kernel.size(3) // 2, + kernel.size(4) // 2, + kernel.size(4) // 2, + ] + out_ch: int = 6 if order == 2 else 3 + out = F.conv3d(F.pad(input, spatial_pad, 'replicate'), kernel_flip, padding=0, groups=c).view( + b, c, out_ch, d, h, w + ) + return out + + +def sobel(input: torch.Tensor, normalized: bool = True, eps: float = 1e-6) -> torch.Tensor: + r"""Compute the Sobel operator and returns the magnitude per channel. + + .. image:: _static/img/sobel.png + + Args: + input: the input image with shape :math:`(B,C,H,W)`. + normalized: if True, L1 norm of the kernel is set to 1. + eps: regularization number to avoid NaN during backprop. + + Return: + the sobel edge gradient magnitudes map with shape :math:`(B,C,H,W)`. + + .. note:: + See a working example `here `__. + + Example: + >>> input = torch.rand(1, 3, 4, 4) + >>> output = sobel(input) # 1x3x4x4 + >>> output.shape + torch.Size([1, 3, 4, 4]) + """ + if not isinstance(input, torch.Tensor): + raise TypeError(f"Input type is not a torch.Tensor. Got {type(input)}") + + if not len(input.shape) == 4: + raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}") + + # comput the x/y gradients + edges: torch.Tensor = spatial_gradient(input, normalized=normalized) + + # unpack the edges + gx: torch.Tensor = edges[:, :, 0] + gy: torch.Tensor = edges[:, :, 1] + + # compute gradient maginitude + magnitude: torch.Tensor = torch.sqrt(gx * gx + gy * gy + eps) + + return magnitude + + +class SpatialGradient(nn.Module): + r"""Compute the first order image derivative in both x and y using a Sobel operator. + + Args: + mode: derivatives modality, can be: `sobel` or `diff`. + order: the order of the derivatives. + normalized: whether the output is normalized. + + Return: + the sobel edges of the input feature map. 
+ + Shape: + - Input: :math:`(B, C, H, W)` + - Output: :math:`(B, C, 2, H, W)` + + Examples: + >>> input = torch.rand(1, 3, 4, 4) + >>> output = SpatialGradient()(input) # 1x3x2x4x4 + """ + + def __init__(self, mode: str = 'sobel', order: int = 1, normalized: bool = True) -> None: + super().__init__() + self.normalized: bool = normalized + self.order: int = order + self.mode: str = mode + + def __repr__(self) -> str: + return ( + self.__class__.__name__ + '(' + 'order=' + str(self.order) + ', ' + 'normalized=' + str(self.normalized) + ', ' + 'mode=' + self.mode + ')' + ) + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return spatial_gradient(input, self.mode, self.order, self.normalized) + + +class SpatialGradient3d(nn.Module): + r"""Compute the first and second order volume derivative in x, y and d using a diff operator. + + Args: + mode: derivatives modality, can be: `sobel` or `diff`. + order: the order of the derivatives. + + Return: + the spatial gradients of the input feature map. + + Shape: + - Input: :math:`(B, C, D, H, W)`. D, H, W are spatial dimensions, gradient is calculated w.r.t to them. + - Output: :math:`(B, C, 3, D, H, W)` or :math:`(B, C, 6, D, H, W)` + + Examples: + >>> input = torch.rand(1, 4, 2, 4, 4) + >>> output = SpatialGradient3d()(input) + >>> output.shape + torch.Size([1, 4, 3, 2, 4, 4]) + """ + + def __init__(self, mode: str = 'diff', order: int = 1) -> None: + super().__init__() + self.order: int = order + self.mode: str = mode + self.kernel = get_spatial_gradient_kernel3d(mode, order) + return + + def __repr__(self) -> str: + return self.__class__.__name__ + '(' 'order=' + str(self.order) + ', ' + 'mode=' + self.mode + ')' + + def forward(self, input: torch.Tensor) -> torch.Tensor: # type: ignore + return spatial_gradient3d(input, self.mode, self.order) + + +class Sobel(nn.Module): + r"""Compute the Sobel operator and returns the magnitude per channel. + + Args: + normalized: if True, L1 norm of the kernel is set to 1. + eps: regularization number to avoid NaN during backprop. + + Return: + the sobel edge gradient magnitudes map. + + Shape: + - Input: :math:`(B, C, H, W)` + - Output: :math:`(B, C, H, W)` + + Examples: + >>> input = torch.rand(1, 3, 4, 4) + >>> output = Sobel()(input) # 1x3x4x4 + """ + + def __init__(self, normalized: bool = True, eps: float = 1e-6) -> None: + super().__init__() + self.normalized: bool = normalized + self.eps: float = eps + + def __repr__(self) -> str: + return self.__class__.__name__ + '(' 'normalized=' + str(self.normalized) + ')' + + def forward(self, input: torch.Tensor) -> torch.Tensor: + return sobel(input, self.normalized, self.eps) \ No newline at end of file diff --git a/propainter/model/misc.py b/propainter/model/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..43b849902245dd338a36f4f4ff09e33425365af6 --- /dev/null +++ b/propainter/model/misc.py @@ -0,0 +1,131 @@ +import os +import re +import random +import time +import torch +import torch.nn as nn +import logging +import numpy as np +from os import path as osp + +def constant_init(module, val, bias=0): + if hasattr(module, 'weight') and module.weight is not None: + nn.init.constant_(module.weight, val) + if hasattr(module, 'bias') and module.bias is not None: + nn.init.constant_(module.bias, bias) + +initialized_logger = {} +def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None): + """Get the root logger. + The logger will be initialized if it has not been initialized. 
By default a + StreamHandler will be added. If `log_file` is specified, a FileHandler will + also be added. + Args: + logger_name (str): root logger name. Default: 'basicsr'. + log_file (str | None): The log filename. If specified, a FileHandler + will be added to the root logger. + log_level (int): The root logger level. Note that only the process of + rank 0 is affected, while other processes will set the level to + "Error" and be silent most of the time. + Returns: + logging.Logger: The root logger. + """ + logger = logging.getLogger(logger_name) + # if the logger has been initialized, just return it + if logger_name in initialized_logger: + return logger + + format_str = '%(asctime)s %(levelname)s: %(message)s' + stream_handler = logging.StreamHandler() + stream_handler.setFormatter(logging.Formatter(format_str)) + logger.addHandler(stream_handler) + logger.propagate = False + + if log_file is not None: + logger.setLevel(log_level) + # add file handler + # file_handler = logging.FileHandler(log_file, 'w') + file_handler = logging.FileHandler(log_file, 'a') #Shangchen: keep the previous log + file_handler.setFormatter(logging.Formatter(format_str)) + file_handler.setLevel(log_level) + logger.addHandler(file_handler) + initialized_logger[logger_name] = True + return logger + + +IS_HIGH_VERSION = [int(m) for m in list(re.findall(r"^([0-9]+)\.([0-9]+)\.([0-9]+)([^0-9][a-zA-Z0-9]*)?(\+git.*)?$",\ + torch.__version__)[0][:3])] >= [1, 12, 0] + +def gpu_is_available(): + if IS_HIGH_VERSION: + if torch.backends.mps.is_available(): + return True + return True if torch.cuda.is_available() and torch.backends.cudnn.is_available() else False + +def get_device(gpu_id=None): + if gpu_id is None: + gpu_str = '' + elif isinstance(gpu_id, int): + gpu_str = f':{gpu_id}' + else: + raise TypeError('Input should be int value.') + + if IS_HIGH_VERSION: + if torch.backends.mps.is_available(): + return torch.device('mps'+gpu_str) + return torch.device('cuda'+gpu_str if torch.cuda.is_available() and torch.backends.cudnn.is_available() else 'cpu') + + +def set_random_seed(seed): + """Set random seeds.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_time_str(): + return time.strftime('%Y%m%d_%H%M%S', time.localtime()) + + +def scandir(dir_path, suffix=None, recursive=False, full_path=False): + """Scan a directory to find the interested files. + + Args: + dir_path (str): Path of the directory. + suffix (str | tuple(str), optional): File suffix that we are + interested in. Default: None. + recursive (bool, optional): If set to True, recursively scan the + directory. Default: False. + full_path (bool, optional): If set to True, include the dir_path. + Default: False. + + Returns: + A generator for all the interested files with relative pathes. 
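+        Example (illustrative; ``'./data'`` and ``'.png'`` are placeholders): ``sorted(scandir('./data', suffix='.png', recursive=True))`` collects the paths of every ``.png`` file under ``./data``, relative to it unless ``full_path=True``.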
+ """ + + if (suffix is not None) and not isinstance(suffix, (str, tuple)): + raise TypeError('"suffix" must be a string or tuple of strings') + + root = dir_path + + def _scandir(dir_path, suffix, recursive): + for entry in os.scandir(dir_path): + if not entry.name.startswith('.') and entry.is_file(): + if full_path: + return_path = entry.path + else: + return_path = osp.relpath(entry.path, root) + + if suffix is None: + yield return_path + elif return_path.endswith(suffix): + yield return_path + else: + if recursive: + yield from _scandir(entry.path, suffix=suffix, recursive=recursive) + else: + continue + + return _scandir(dir_path, suffix=suffix, recursive=recursive) \ No newline at end of file diff --git a/propainter/model/modules/base_module.py b/propainter/model/modules/base_module.py new file mode 100644 index 0000000000000000000000000000000000000000..b28c094308dd4d1bbb62dd75e02e937e2c9ddf14 --- /dev/null +++ b/propainter/model/modules/base_module.py @@ -0,0 +1,131 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from functools import reduce + +class BaseNetwork(nn.Module): + def __init__(self): + super(BaseNetwork, self).__init__() + + def print_network(self): + if isinstance(self, list): + self = self[0] + num_params = 0 + for param in self.parameters(): + num_params += param.numel() + print( + 'Network [%s] was created. Total number of parameters: %.1f million. ' + 'To see the architecture, do print(network).' % + (type(self).__name__, num_params / 1000000)) + + def init_weights(self, init_type='normal', gain=0.02): + ''' + initialize network's weights + init_type: normal | xavier | kaiming | orthogonal + https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/9451e70673400885567d08a9e97ade2524c700d0/models/networks.py#L39 + ''' + def init_func(m): + classname = m.__class__.__name__ + if classname.find('InstanceNorm2d') != -1: + if hasattr(m, 'weight') and m.weight is not None: + nn.init.constant_(m.weight.data, 1.0) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias.data, 0.0) + elif hasattr(m, 'weight') and (classname.find('Conv') != -1 + or classname.find('Linear') != -1): + if init_type == 'normal': + nn.init.normal_(m.weight.data, 0.0, gain) + elif init_type == 'xavier': + nn.init.xavier_normal_(m.weight.data, gain=gain) + elif init_type == 'xavier_uniform': + nn.init.xavier_uniform_(m.weight.data, gain=1.0) + elif init_type == 'kaiming': + nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') + elif init_type == 'orthogonal': + nn.init.orthogonal_(m.weight.data, gain=gain) + elif init_type == 'none': # uses pytorch's default init method + m.reset_parameters() + else: + raise NotImplementedError( + 'initialization method [%s] is not implemented' % + init_type) + if hasattr(m, 'bias') and m.bias is not None: + nn.init.constant_(m.bias.data, 0.0) + + self.apply(init_func) + + # propagate to children + for m in self.children(): + if hasattr(m, 'init_weights'): + m.init_weights(init_type, gain) + + +class Vec2Feat(nn.Module): + def __init__(self, channel, hidden, kernel_size, stride, padding): + super(Vec2Feat, self).__init__() + self.relu = nn.LeakyReLU(0.2, inplace=True) + c_out = reduce((lambda x, y: x * y), kernel_size) * channel + self.embedding = nn.Linear(hidden, c_out) + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.bias_conv = nn.Conv2d(channel, + channel, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, t, output_size): + b_, _, _, _, c_ = x.shape + x 
= x.view(b_, -1, c_) + feat = self.embedding(x) + b, _, c = feat.size() + feat = feat.view(b * t, -1, c).permute(0, 2, 1) + feat = F.fold(feat, + output_size=output_size, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding) + feat = self.bias_conv(feat) + return feat + + +class FusionFeedForward(nn.Module): + def __init__(self, dim, hidden_dim=1960, t2t_params=None): + super(FusionFeedForward, self).__init__() + # We set hidden_dim as a default to 1960 + self.fc1 = nn.Sequential(nn.Linear(dim, hidden_dim)) + self.fc2 = nn.Sequential(nn.GELU(), nn.Linear(hidden_dim, dim)) + assert t2t_params is not None + self.t2t_params = t2t_params + self.kernel_shape = reduce((lambda x, y: x * y), t2t_params['kernel_size']) # 49 + + def forward(self, x, output_size): + n_vecs = 1 + for i, d in enumerate(self.t2t_params['kernel_size']): + n_vecs *= int((output_size[i] + 2 * self.t2t_params['padding'][i] - + (d - 1) - 1) / self.t2t_params['stride'][i] + 1) + + x = self.fc1(x) + b, n, c = x.size() + normalizer = x.new_ones(b, n, self.kernel_shape).view(-1, n_vecs, self.kernel_shape).permute(0, 2, 1) + normalizer = F.fold(normalizer, + output_size=output_size, + kernel_size=self.t2t_params['kernel_size'], + padding=self.t2t_params['padding'], + stride=self.t2t_params['stride']) + + x = F.fold(x.view(-1, n_vecs, c).permute(0, 2, 1), + output_size=output_size, + kernel_size=self.t2t_params['kernel_size'], + padding=self.t2t_params['padding'], + stride=self.t2t_params['stride']) + + x = F.unfold(x / normalizer, + kernel_size=self.t2t_params['kernel_size'], + padding=self.t2t_params['padding'], + stride=self.t2t_params['stride']).permute( + 0, 2, 1).contiguous().view(b, n, c) + x = self.fc2(x) + return x diff --git a/propainter/model/modules/deformconv.py b/propainter/model/modules/deformconv.py new file mode 100644 index 0000000000000000000000000000000000000000..89cb31b3d80bd69704a380930964db6fb29a6bbe --- /dev/null +++ b/propainter/model/modules/deformconv.py @@ -0,0 +1,54 @@ +import torch +import torch.nn as nn +from torch.nn import init as init +from torch.nn.modules.utils import _pair, _single +import math + +class ModulatedDeformConv2d(nn.Module): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + padding=0, + dilation=1, + groups=1, + deform_groups=1, + bias=True): + super(ModulatedDeformConv2d, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = _pair(kernel_size) + self.stride = stride + self.padding = padding + self.dilation = dilation + self.groups = groups + self.deform_groups = deform_groups + self.with_bias = bias + # enable compatibility with nn.Conv2d + self.transposed = False + self.output_padding = _single(0) + + self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels // groups, *self.kernel_size)) + if bias: + self.bias = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter('bias', None) + self.init_weights() + + def init_weights(self): + n = self.in_channels + for k in self.kernel_size: + n *= k + stdv = 1. 
/ math.sqrt(n) + self.weight.data.uniform_(-stdv, stdv) + if self.bias is not None: + self.bias.data.zero_() + + if hasattr(self, 'conv_offset'): + self.conv_offset.weight.data.zero_() + self.conv_offset.bias.data.zero_() + + def forward(self, x, offset, mask): + pass \ No newline at end of file diff --git a/propainter/model/modules/flow_comp_raft.py b/propainter/model/modules/flow_comp_raft.py new file mode 100644 index 0000000000000000000000000000000000000000..b7f09eba83186726fcc112b7b80172dcd74d5d69 --- /dev/null +++ b/propainter/model/modules/flow_comp_raft.py @@ -0,0 +1,270 @@ +import argparse +import torch +import torch.nn as nn +import torch.nn.functional as F + +try: + from RAFT import RAFT + from model.modules.flow_loss_utils import flow_warp, ternary_loss2 +except: + from propainter.RAFT import RAFT + from propainter.model.modules.flow_loss_utils import flow_warp, ternary_loss2 + + + +def initialize_RAFT(model_path='weights/raft-things.pth', device='cuda'): + """Initializes the RAFT model. + """ + args = argparse.ArgumentParser() + args.raft_model = model_path + args.small = False + args.mixed_precision = False + args.alternate_corr = False + model = torch.nn.DataParallel(RAFT(args)) + model.load_state_dict(torch.load(args.raft_model, map_location='cpu')) + model = model.module + + model.to(device) + + return model + + +class RAFT_bi(nn.Module): + """Flow completion loss""" + def __init__(self, model_path='weights/raft-things.pth', device='cuda'): + super().__init__() + self.fix_raft = initialize_RAFT(model_path, device=device) + + for p in self.fix_raft.parameters(): + p.requires_grad = False + + self.l1_criterion = nn.L1Loss() + self.eval() + + def forward(self, gt_local_frames, iters=20): + b, l_t, c, h, w = gt_local_frames.size() + # print(gt_local_frames.shape) + + with torch.no_grad(): + gtlf_1 = gt_local_frames[:, :-1, :, :, :].reshape(-1, c, h, w) + gtlf_2 = gt_local_frames[:, 1:, :, :, :].reshape(-1, c, h, w) + # print(gtlf_1.shape) + + _, gt_flows_forward = self.fix_raft(gtlf_1, gtlf_2, iters=iters, test_mode=True) + _, gt_flows_backward = self.fix_raft(gtlf_2, gtlf_1, iters=iters, test_mode=True) + + + gt_flows_forward = gt_flows_forward.view(b, l_t-1, 2, h, w) + gt_flows_backward = gt_flows_backward.view(b, l_t-1, 2, h, w) + + return gt_flows_forward, gt_flows_backward + + +################################################################################## +def smoothness_loss(flow, cmask): + delta_u, delta_v, mask = smoothness_deltas(flow) + loss_u = charbonnier_loss(delta_u, cmask) + loss_v = charbonnier_loss(delta_v, cmask) + return loss_u + loss_v + + +def smoothness_deltas(flow): + """ + flow: [b, c, h, w] + """ + mask_x = create_mask(flow, [[0, 0], [0, 1]]) + mask_y = create_mask(flow, [[0, 1], [0, 0]]) + mask = torch.cat((mask_x, mask_y), dim=1) + mask = mask.to(flow.device) + filter_x = torch.tensor([[0, 0, 0.], [0, 1, -1], [0, 0, 0]]) + filter_y = torch.tensor([[0, 0, 0.], [0, 1, 0], [0, -1, 0]]) + weights = torch.ones([2, 1, 3, 3]) + weights[0, 0] = filter_x + weights[1, 0] = filter_y + weights = weights.to(flow.device) + + flow_u, flow_v = torch.split(flow, split_size_or_sections=1, dim=1) + delta_u = F.conv2d(flow_u, weights, stride=1, padding=1) + delta_v = F.conv2d(flow_v, weights, stride=1, padding=1) + return delta_u, delta_v, mask + + +def second_order_loss(flow, cmask): + delta_u, delta_v, mask = second_order_deltas(flow) + loss_u = charbonnier_loss(delta_u, cmask) + loss_v = charbonnier_loss(delta_v, cmask) + return loss_u + loss_v + + +def 
charbonnier_loss(x, mask=None, truncate=None, alpha=0.45, beta=1.0, epsilon=0.001): + """ + Compute the generalized charbonnier loss of the difference tensor x + All positions where mask == 0 are not taken into account + x: a tensor of shape [b, c, h, w] + mask: a mask of shape [b, mc, h, w], where mask channels must be either 1 or the same as + the number of channels of x. Entries should be 0 or 1 + return: loss + """ + b, c, h, w = x.shape + norm = b * c * h * w + error = torch.pow(torch.square(x * beta) + torch.square(torch.tensor(epsilon)), alpha) + if mask is not None: + error = mask * error + if truncate is not None: + error = torch.min(error, truncate) + return torch.sum(error) / norm + + +def second_order_deltas(flow): + """ + consider the single flow first + flow shape: [b, c, h, w] + """ + # create mask + mask_x = create_mask(flow, [[0, 0], [1, 1]]) + mask_y = create_mask(flow, [[1, 1], [0, 0]]) + mask_diag = create_mask(flow, [[1, 1], [1, 1]]) + mask = torch.cat((mask_x, mask_y, mask_diag, mask_diag), dim=1) + mask = mask.to(flow.device) + + filter_x = torch.tensor([[0, 0, 0.], [1, -2, 1], [0, 0, 0]]) + filter_y = torch.tensor([[0, 1, 0.], [0, -2, 0], [0, 1, 0]]) + filter_diag1 = torch.tensor([[1, 0, 0.], [0, -2, 0], [0, 0, 1]]) + filter_diag2 = torch.tensor([[0, 0, 1.], [0, -2, 0], [1, 0, 0]]) + weights = torch.ones([4, 1, 3, 3]) + weights[0] = filter_x + weights[1] = filter_y + weights[2] = filter_diag1 + weights[3] = filter_diag2 + weights = weights.to(flow.device) + + # split the flow into flow_u and flow_v, conv them with the weights + flow_u, flow_v = torch.split(flow, split_size_or_sections=1, dim=1) + delta_u = F.conv2d(flow_u, weights, stride=1, padding=1) + delta_v = F.conv2d(flow_v, weights, stride=1, padding=1) + return delta_u, delta_v, mask + +def create_mask(tensor, paddings): + """ + tensor shape: [b, c, h, w] + paddings: [2 x 2] shape list, the first row indicates up and down paddings + the second row indicates left and right paddings + | | + | x | + | x * x | + | x | + | | + """ + shape = tensor.shape + inner_height = shape[2] - (paddings[0][0] + paddings[0][1]) + inner_width = shape[3] - (paddings[1][0] + paddings[1][1]) + inner = torch.ones([inner_height, inner_width]) + torch_paddings = [paddings[1][0], paddings[1][1], paddings[0][0], paddings[0][1]] # left, right, up and down + mask2d = F.pad(inner, pad=torch_paddings) + mask3d = mask2d.unsqueeze(0).repeat(shape[0], 1, 1) + mask4d = mask3d.unsqueeze(1) + return mask4d.detach() + +def ternary_loss(flow_comp, flow_gt, mask, current_frame, shift_frame, scale_factor=1): + if scale_factor != 1: + current_frame = F.interpolate(current_frame, scale_factor=1 / scale_factor, mode='bilinear') + shift_frame = F.interpolate(shift_frame, scale_factor=1 / scale_factor, mode='bilinear') + warped_sc = flow_warp(shift_frame, flow_gt.permute(0, 2, 3, 1)) + noc_mask = torch.exp(-50. * torch.sum(torch.abs(current_frame - warped_sc), dim=1).pow(2)).unsqueeze(1) + warped_comp_sc = flow_warp(shift_frame, flow_comp.permute(0, 2, 3, 1)) + loss = ternary_loss2(current_frame, warped_comp_sc, noc_mask, mask) + return loss + +class FlowLoss(nn.Module): + def __init__(self): + super().__init__() + self.l1_criterion = nn.L1Loss() + + def forward(self, pred_flows, gt_flows, masks, frames): + # pred_flows: b t-1 2 h w + loss = 0 + warp_loss = 0 + h, w = pred_flows[0].shape[-2:] + masks = [masks[:,:-1,...].contiguous(), masks[:, 1:, ...].contiguous()] + frames0 = frames[:,:-1,...] + frames1 = frames[:,1:,...] 
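# --- Editor's note (illustrative sketch, not part of the patch) -------------
# The generalized Charbonnier penalty defined above is a smooth surrogate for
# an L1 loss: rho(x) = ((beta*x)^2 + eps^2)^alpha with alpha=0.45, beta=1.0,
# eps=1e-3, averaged over all b*c*h*w positions (entries with mask == 0
# contribute nothing). A minimal sanity check, assuming only torch and the
# charbonnier_loss above:
#
#     x = torch.randn(2, 2, 8, 8)                 # e.g. a flow residual
#     m = (torch.rand(2, 1, 8, 8) > 0.5).float()  # 0/1 validity mask
#     val = charbonnier_loss(x, m)                # scalar tensor
#
# frames0/frames1 above are the (t, t+1) frame pairs: forward flows map frame
# t to t+1 and backward flows map t+1 to t, which is why current_frames and
# next_frames below swap their order for the two flow directions.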
+ current_frames = [frames0, frames1] + next_frames = [frames1, frames0] + for i in range(len(pred_flows)): + # print(pred_flows[i].shape) + combined_flow = pred_flows[i] * masks[i] + gt_flows[i] * (1-masks[i]) + l1_loss = self.l1_criterion(pred_flows[i] * masks[i], gt_flows[i] * masks[i]) / torch.mean(masks[i]) + l1_loss += self.l1_criterion(pred_flows[i] * (1-masks[i]), gt_flows[i] * (1-masks[i])) / torch.mean((1-masks[i])) + + smooth_loss = smoothness_loss(combined_flow.reshape(-1,2,h,w), masks[i].reshape(-1,1,h,w)) + smooth_loss2 = second_order_loss(combined_flow.reshape(-1,2,h,w), masks[i].reshape(-1,1,h,w)) + + warp_loss_i = ternary_loss(combined_flow.reshape(-1,2,h,w), gt_flows[i].reshape(-1,2,h,w), + masks[i].reshape(-1,1,h,w), current_frames[i].reshape(-1,3,h,w), next_frames[i].reshape(-1,3,h,w)) + + loss += l1_loss + smooth_loss + smooth_loss2 + + warp_loss += warp_loss_i + + return loss, warp_loss + + +def edgeLoss(preds_edges, edges): + """ + + Args: + preds_edges: with shape [b, c, h , w] + edges: with shape [b, c, h, w] + + Returns: Edge losses + + """ + mask = (edges > 0.5).float() + b, c, h, w = mask.shape + num_pos = torch.sum(mask, dim=[1, 2, 3]).float() # Shape: [b,]. + num_neg = c * h * w - num_pos # Shape: [b,]. + neg_weights = (num_neg / (num_pos + num_neg)).unsqueeze(1).unsqueeze(2).unsqueeze(3) + pos_weights = (num_pos / (num_pos + num_neg)).unsqueeze(1).unsqueeze(2).unsqueeze(3) + weight = neg_weights * mask + pos_weights * (1 - mask) # weight for debug + losses = F.binary_cross_entropy_with_logits(preds_edges.float(), edges.float(), weight=weight, reduction='none') + loss = torch.mean(losses) + return loss + +class EdgeLoss(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, pred_edges, gt_edges, masks): + # pred_flows: b t-1 1 h w + loss = 0 + h, w = pred_edges[0].shape[-2:] + masks = [masks[:,:-1,...].contiguous(), masks[:, 1:, ...].contiguous()] + for i in range(len(pred_edges)): + # print(f'edges_{i}', torch.sum(gt_edges[i])) # debug + combined_edge = pred_edges[i] * masks[i] + gt_edges[i] * (1-masks[i]) + edge_loss = (edgeLoss(pred_edges[i].reshape(-1,1,h,w), gt_edges[i].reshape(-1,1,h,w)) \ + + 5 * edgeLoss(combined_edge.reshape(-1,1,h,w), gt_edges[i].reshape(-1,1,h,w))) + loss += edge_loss + + return loss + + +class FlowSimpleLoss(nn.Module): + def __init__(self): + super().__init__() + self.l1_criterion = nn.L1Loss() + + def forward(self, pred_flows, gt_flows): + # pred_flows: b t-1 2 h w + loss = 0 + h, w = pred_flows[0].shape[-2:] + h_orig, w_orig = gt_flows[0].shape[-2:] + pred_flows = [f.view(-1, 2, h, w) for f in pred_flows] + gt_flows = [f.view(-1, 2, h_orig, w_orig) for f in gt_flows] + + ds_factor = 1.0*h/h_orig + gt_flows = [F.interpolate(f, scale_factor=ds_factor, mode='area') * ds_factor for f in gt_flows] + for i in range(len(pred_flows)): + loss += self.l1_criterion(pred_flows[i], gt_flows[i]) + + return loss \ No newline at end of file diff --git a/propainter/model/modules/flow_loss_utils.py b/propainter/model/modules/flow_loss_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6e465c0605df760920b5cfc7f9079fadb74fbec1 --- /dev/null +++ b/propainter/model/modules/flow_loss_utils.py @@ -0,0 +1,142 @@ +import torch +import numpy as np +import torch.nn as nn +import torch.nn.functional as F + +def flow_warp(x, + flow, + interpolation='bilinear', + padding_mode='zeros', + align_corners=True): + """Warp an image or a feature map with optical flow. 
+ Args: + x (Tensor): Tensor with size (n, c, h, w). + flow (Tensor): Tensor with size (n, h, w, 2). The last dimension is + a two-channel, denoting the width and height relative offsets. + Note that the values are not normalized to [-1, 1]. + interpolation (str): Interpolation mode: 'nearest' or 'bilinear'. + Default: 'bilinear'. + padding_mode (str): Padding mode: 'zeros' or 'border' or 'reflection'. + Default: 'zeros'. + align_corners (bool): Whether align corners. Default: True. + Returns: + Tensor: Warped image or feature map. + """ + if x.size()[-2:] != flow.size()[1:3]: + raise ValueError(f'The spatial sizes of input ({x.size()[-2:]}) and ' + f'flow ({flow.size()[1:3]}) are not the same.') + _, _, h, w = x.size() + # create mesh grid + device = flow.device + grid_y, grid_x = torch.meshgrid(torch.arange(0, h, device=device), torch.arange(0, w, device=device)) + grid = torch.stack((grid_x, grid_y), 2).type_as(x) # (w, h, 2) + grid.requires_grad = False + + grid_flow = grid + flow + # scale grid_flow to [-1,1] + grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0 + grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0 + grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim=3) + output = F.grid_sample(x, + grid_flow, + mode=interpolation, + padding_mode=padding_mode, + align_corners=align_corners) + return output + + +# def image_warp(image, flow): +# b, c, h, w = image.size() +# device = image.device +# flow = torch.cat([flow[:, 0:1, :, :] / ((w - 1.0) / 2.0), flow[:, 1:2, :, :] / ((h - 1.0) / 2.0)], dim=1) # normalize to [-1~1](from upper left to lower right +# flow = flow.permute(0, 2, 3, 1) # if you wanna use grid_sample function, the channel(band) shape of show must be in the last dimension +# x = np.linspace(-1, 1, w) +# y = np.linspace(-1, 1, h) +# X, Y = np.meshgrid(x, y) +# grid = torch.cat((torch.from_numpy(X.astype('float32')).unsqueeze(0).unsqueeze(3), +# torch.from_numpy(Y.astype('float32')).unsqueeze(0).unsqueeze(3)), 3).to(device) +# output = torch.nn.functional.grid_sample(image, grid + flow, mode='bilinear', padding_mode='zeros') +# return output + + +def length_sq(x): + return torch.sum(torch.square(x), dim=1, keepdim=True) + + +def fbConsistencyCheck(flow_fw, flow_bw, alpha1=0.01, alpha2=0.5): + flow_bw_warped = flow_warp(flow_bw, flow_fw.permute(0, 2, 3, 1)) # wb(wf(x)) + flow_fw_warped = flow_warp(flow_fw, flow_bw.permute(0, 2, 3, 1)) # wf(wb(x)) + flow_diff_fw = flow_fw + flow_bw_warped # wf + wb(wf(x)) + flow_diff_bw = flow_bw + flow_fw_warped # wb + wf(wb(x)) + + mag_sq_fw = length_sq(flow_fw) + length_sq(flow_bw_warped) # |wf| + |wb(wf(x))| + mag_sq_bw = length_sq(flow_bw) + length_sq(flow_fw_warped) # |wb| + |wf(wb(x))| + occ_thresh_fw = alpha1 * mag_sq_fw + alpha2 + occ_thresh_bw = alpha1 * mag_sq_bw + alpha2 + + fb_occ_fw = (length_sq(flow_diff_fw) > occ_thresh_fw).float() + fb_occ_bw = (length_sq(flow_diff_bw) > occ_thresh_bw).float() + + return fb_occ_fw, fb_occ_bw # fb_occ_fw -> frame2 area occluded by frame1, fb_occ_bw -> frame1 area occluded by frame2 + + +def rgb2gray(image): + gray_image = image[:, 0] * 0.299 + image[:, 1] * 0.587 + 0.110 * image[:, 2] + gray_image = gray_image.unsqueeze(1) + return gray_image + + +def ternary_transform(image, max_distance=1): + device = image.device + patch_size = 2 * max_distance + 1 + intensities = rgb2gray(image) * 255 + out_channels = patch_size * patch_size + w = np.eye(out_channels).reshape(out_channels, 1, patch_size, patch_size) + weights = torch.from_numpy(w).float().to(device) + 
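# Editor's note: the block below is a census/ternary transform. Each of the
# patch_size*patch_size output channels of the conv holds one neighbour of the
# centre pixel (the identity filters in `weights`), so `patches - intensities`
# compares every neighbour against the centre, and the 1/sqrt(0.81 + d^2)
# normalisation bounds each channel to (-1, 1). The hard-coded padding=1 only
# matches max_distance=1; for other distances the padding would need to be
# patch_size // 2. The grayscale weights above (0.299, 0.587, 0.110) are close
# to, but not exactly, the usual Rec. 601 luma coefficients.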
patches = F.conv2d(intensities, weights, stride=1, padding=1) + transf = patches - intensities + transf_norm = transf / torch.sqrt(0.81 + torch.square(transf)) + return transf_norm + + +def hamming_distance(t1, t2): + dist = torch.square(t1 - t2) + dist_norm = dist / (0.1 + dist) + dist_sum = torch.sum(dist_norm, dim=1, keepdim=True) + return dist_sum + + +def create_mask(mask, paddings): + """ + padding: [[top, bottom], [left, right]] + """ + shape = mask.shape + inner_height = shape[2] - (paddings[0][0] + paddings[0][1]) + inner_width = shape[3] - (paddings[1][0] + paddings[1][1]) + inner = torch.ones([inner_height, inner_width]) + + mask2d = F.pad(inner, pad=[paddings[1][0], paddings[1][1], paddings[0][0], paddings[0][1]]) + mask3d = mask2d.unsqueeze(0) + mask4d = mask3d.unsqueeze(0).repeat(shape[0], 1, 1, 1) + return mask4d.detach() + + +def ternary_loss2(frame1, warp_frame21, confMask, masks, max_distance=1): + """ + + Args: + frame1: torch tensor, with shape [b * t, c, h, w] + warp_frame21: torch tensor, with shape [b * t, c, h, w] + confMask: confidence mask, with shape [b * t, c, h, w] + masks: torch tensor, with shape [b * t, c, h, w] + max_distance: maximum distance. + + Returns: ternary loss + + """ + t1 = ternary_transform(frame1) + t21 = ternary_transform(warp_frame21) + dist = hamming_distance(t1, t21) + loss = torch.mean(dist * confMask * masks) / torch.mean(masks) + return loss + diff --git a/propainter/model/modules/sparse_transformer.py b/propainter/model/modules/sparse_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..11028ffe05a0f59e9d222b0a18f92b1fde12007b --- /dev/null +++ b/propainter/model/modules/sparse_transformer.py @@ -0,0 +1,344 @@ +import math +from functools import reduce +import torch +import torch.nn as nn +import torch.nn.functional as F + +class SoftSplit(nn.Module): + def __init__(self, channel, hidden, kernel_size, stride, padding): + super(SoftSplit, self).__init__() + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.t2t = nn.Unfold(kernel_size=kernel_size, + stride=stride, + padding=padding) + c_in = reduce((lambda x, y: x * y), kernel_size) * channel + self.embedding = nn.Linear(c_in, hidden) + + def forward(self, x, b, output_size): + f_h = int((output_size[0] + 2 * self.padding[0] - + (self.kernel_size[0] - 1) - 1) / self.stride[0] + 1) + f_w = int((output_size[1] + 2 * self.padding[1] - + (self.kernel_size[1] - 1) - 1) / self.stride[1] + 1) + + feat = self.t2t(x) + feat = feat.permute(0, 2, 1) + # feat shape [b*t, num_vec, ks*ks*c] + feat = self.embedding(feat) + # feat shape after embedding [b, t*num_vec, hidden] + feat = feat.view(b, -1, f_h, f_w, feat.size(2)) + return feat + + +class SoftComp(nn.Module): + def __init__(self, channel, hidden, kernel_size, stride, padding): + super(SoftComp, self).__init__() + self.relu = nn.LeakyReLU(0.2, inplace=True) + c_out = reduce((lambda x, y: x * y), kernel_size) * channel + self.embedding = nn.Linear(hidden, c_out) + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.bias_conv = nn.Conv2d(channel, + channel, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x, t, output_size): + b_, _, _, _, c_ = x.shape + x = x.view(b_, -1, c_) + feat = self.embedding(x) + b, _, c = feat.size() + feat = feat.view(b * t, -1, c).permute(0, 2, 1) + feat = F.fold(feat, + output_size=output_size, + kernel_size=self.kernel_size, + stride=self.stride, + padding=self.padding) + feat = self.bias_conv(feat) + 
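# Editor's note: SoftSplit and SoftComp act as approximate inverses. SoftSplit
# unfolds overlapping 7x7 patches and projects them to `hidden`-dim tokens;
# SoftComp projects tokens back to patch pixels and F.fold sums the overlaps
# onto the frame, with bias_conv smoothing the resulting seams. A rough shape
# walk-through (editor's illustration; values follow the defaults used later
# by InpaintGenerator: channel=128, hidden=512, kernel=(7,7), stride=(3,3),
# padding=(3,3), feature size 60x108):
#
#     x = torch.randn(2 * 5, 128, 60, 108)            # (b*t, c, h, w)
#     tok = soft_split(x, b=2, output_size=(60, 108))
#     # tok: (2, 5, 20, 36, 512), since f_h = (60+6-6-1)//3 + 1 = 20, f_w = 36
#     y = soft_comp(tok, t=5, output_size=(60, 108))
#     # y: (10, 128, 60, 108)
#
# where `soft_split` / `soft_comp` stand for instances of the classes above.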
return feat + + +class FusionFeedForward(nn.Module): + def __init__(self, dim, hidden_dim=1960, t2t_params=None): + super(FusionFeedForward, self).__init__() + # We set hidden_dim as a default to 1960 + self.fc1 = nn.Sequential(nn.Linear(dim, hidden_dim)) + self.fc2 = nn.Sequential(nn.GELU(), nn.Linear(hidden_dim, dim)) + assert t2t_params is not None + self.t2t_params = t2t_params + self.kernel_shape = reduce((lambda x, y: x * y), t2t_params['kernel_size']) # 49 + + def forward(self, x, output_size): + n_vecs = 1 + for i, d in enumerate(self.t2t_params['kernel_size']): + n_vecs *= int((output_size[i] + 2 * self.t2t_params['padding'][i] - + (d - 1) - 1) / self.t2t_params['stride'][i] + 1) + + x = self.fc1(x) + b, n, c = x.size() + normalizer = x.new_ones(b, n, self.kernel_shape).view(-1, n_vecs, self.kernel_shape).permute(0, 2, 1) + normalizer = F.fold(normalizer, + output_size=output_size, + kernel_size=self.t2t_params['kernel_size'], + padding=self.t2t_params['padding'], + stride=self.t2t_params['stride']) + + x = F.fold(x.view(-1, n_vecs, c).permute(0, 2, 1), + output_size=output_size, + kernel_size=self.t2t_params['kernel_size'], + padding=self.t2t_params['padding'], + stride=self.t2t_params['stride']) + + x = F.unfold(x / normalizer, + kernel_size=self.t2t_params['kernel_size'], + padding=self.t2t_params['padding'], + stride=self.t2t_params['stride']).permute( + 0, 2, 1).contiguous().view(b, n, c) + x = self.fc2(x) + return x + + +def window_partition(x, window_size, n_head): + """ + Args: + x: shape is (B, T, H, W, C) + window_size (tuple[int]): window size + Returns: + windows: (B, num_windows_h, num_windows_w, n_head, T, window_size, window_size, C//n_head) + """ + B, T, H, W, C = x.shape + x = x.view(B, T, H // window_size[0], window_size[0], W // window_size[1], window_size[1], n_head, C//n_head) + windows = x.permute(0, 2, 4, 6, 1, 3, 5, 7).contiguous() + return windows + +class SparseWindowAttention(nn.Module): + def __init__(self, dim, n_head, window_size, pool_size=(4,4), qkv_bias=True, attn_drop=0., proj_drop=0., + pooling_token=True): + super().__init__() + assert dim % n_head == 0 + # key, query, value projections for all heads + self.key = nn.Linear(dim, dim, qkv_bias) + self.query = nn.Linear(dim, dim, qkv_bias) + self.value = nn.Linear(dim, dim, qkv_bias) + # regularization + self.attn_drop = nn.Dropout(attn_drop) + self.proj_drop = nn.Dropout(proj_drop) + # output projection + self.proj = nn.Linear(dim, dim) + self.n_head = n_head + self.window_size = window_size + self.pooling_token = pooling_token + if self.pooling_token: + ks, stride = pool_size, pool_size + self.pool_layer = nn.Conv2d(dim, dim, kernel_size=ks, stride=stride, padding=(0, 0), groups=dim) + self.pool_layer.weight.data.fill_(1. 
/ (pool_size[0] * pool_size[1])) + self.pool_layer.bias.data.fill_(0) + # self.expand_size = tuple(i // 2 for i in window_size) + self.expand_size = tuple((i + 1) // 2 for i in window_size) + + if any(i > 0 for i in self.expand_size): + # get mask for rolled k and rolled v + mask_tl = torch.ones(self.window_size[0], self.window_size[1]) + mask_tl[:-self.expand_size[0], :-self.expand_size[1]] = 0 + mask_tr = torch.ones(self.window_size[0], self.window_size[1]) + mask_tr[:-self.expand_size[0], self.expand_size[1]:] = 0 + mask_bl = torch.ones(self.window_size[0], self.window_size[1]) + mask_bl[self.expand_size[0]:, :-self.expand_size[1]] = 0 + mask_br = torch.ones(self.window_size[0], self.window_size[1]) + mask_br[self.expand_size[0]:, self.expand_size[1]:] = 0 + masrool_k = torch.stack((mask_tl, mask_tr, mask_bl, mask_br), 0).flatten(0) + self.register_buffer("valid_ind_rolled", masrool_k.nonzero(as_tuple=False).view(-1)) + + self.max_pool = nn.MaxPool2d(window_size, window_size, (0, 0)) + + + def forward(self, x, mask=None, T_ind=None, attn_mask=None): + b, t, h, w, c = x.shape # 20 36 + w_h, w_w = self.window_size[0], self.window_size[1] + c_head = c // self.n_head + n_wh = math.ceil(h / self.window_size[0]) + n_ww = math.ceil(w / self.window_size[1]) + new_h = n_wh * self.window_size[0] # 20 + new_w = n_ww * self.window_size[1] # 36 + pad_r = new_w - w + pad_b = new_h - h + # reverse order + if pad_r > 0 or pad_b > 0: + x = F.pad(x,(0, 0, 0, pad_r, 0, pad_b, 0, 0), mode='constant', value=0) + mask = F.pad(mask,(0, 0, 0, pad_r, 0, pad_b, 0, 0), mode='constant', value=0) + + # calculate query, key, values for all heads in batch and move head forward to be the batch dim + q = self.query(x) + k = self.key(x) + v = self.value(x) + win_q = window_partition(q.contiguous(), self.window_size, self.n_head).view(b, n_wh*n_ww, self.n_head, t, w_h*w_w, c_head) + win_k = window_partition(k.contiguous(), self.window_size, self.n_head).view(b, n_wh*n_ww, self.n_head, t, w_h*w_w, c_head) + win_v = window_partition(v.contiguous(), self.window_size, self.n_head).view(b, n_wh*n_ww, self.n_head, t, w_h*w_w, c_head) + # roll_k and roll_v + if any(i > 0 for i in self.expand_size): + (k_tl, v_tl) = map(lambda a: torch.roll(a, shifts=(-self.expand_size[0], -self.expand_size[1]), dims=(2, 3)), (k, v)) + (k_tr, v_tr) = map(lambda a: torch.roll(a, shifts=(-self.expand_size[0], self.expand_size[1]), dims=(2, 3)), (k, v)) + (k_bl, v_bl) = map(lambda a: torch.roll(a, shifts=(self.expand_size[0], -self.expand_size[1]), dims=(2, 3)), (k, v)) + (k_br, v_br) = map(lambda a: torch.roll(a, shifts=(self.expand_size[0], self.expand_size[1]), dims=(2, 3)), (k, v)) + + (k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows) = map( + lambda a: window_partition(a, self.window_size, self.n_head).view(b, n_wh*n_ww, self.n_head, t, w_h*w_w, c_head), + (k_tl, k_tr, k_bl, k_br)) + (v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows) = map( + lambda a: window_partition(a, self.window_size, self.n_head).view(b, n_wh*n_ww, self.n_head, t, w_h*w_w, c_head), + (v_tl, v_tr, v_bl, v_br)) + rool_k = torch.cat((k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows), 4).contiguous() + rool_v = torch.cat((v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows), 4).contiguous() # [b, n_wh*n_ww, n_head, t, w_h*w_w, c_head] + # mask out tokens in current window + rool_k = rool_k[:, :, :, :, self.valid_ind_rolled] + rool_v = rool_v[:, :, :, :, self.valid_ind_rolled] + roll_N = rool_k.shape[4] + rool_k = rool_k.view(b, n_wh*n_ww, self.n_head, 
t, roll_N, c // self.n_head) + rool_v = rool_v.view(b, n_wh*n_ww, self.n_head, t, roll_N, c // self.n_head) + win_k = torch.cat((win_k, rool_k), dim=4) + win_v = torch.cat((win_v, rool_v), dim=4) + else: + win_k = win_k + win_v = win_v + + # pool_k and pool_v + if self.pooling_token: + pool_x = self.pool_layer(x.view(b*t, new_h, new_w, c).permute(0,3,1,2)) + _, _, p_h, p_w = pool_x.shape + pool_x = pool_x.permute(0,2,3,1).view(b, t, p_h, p_w, c) + # pool_k + pool_k = self.key(pool_x).unsqueeze(1).repeat(1, n_wh*n_ww, 1, 1, 1, 1) # [b, n_wh*n_ww, t, p_h, p_w, c] + pool_k = pool_k.view(b, n_wh*n_ww, t, p_h, p_w, self.n_head, c_head).permute(0,1,5,2,3,4,6) + pool_k = pool_k.contiguous().view(b, n_wh*n_ww, self.n_head, t, p_h*p_w, c_head) + win_k = torch.cat((win_k, pool_k), dim=4) + # pool_v + pool_v = self.value(pool_x).unsqueeze(1).repeat(1, n_wh*n_ww, 1, 1, 1, 1) # [b, n_wh*n_ww, t, p_h, p_w, c] + pool_v = pool_v.view(b, n_wh*n_ww, t, p_h, p_w, self.n_head, c_head).permute(0,1,5,2,3,4,6) + pool_v = pool_v.contiguous().view(b, n_wh*n_ww, self.n_head, t, p_h*p_w, c_head) + win_v = torch.cat((win_v, pool_v), dim=4) + + # [b, n_wh*n_ww, n_head, t, w_h*w_w, c_head] + out = torch.zeros_like(win_q) + l_t = mask.size(1) + + mask = self.max_pool(mask.view(b * l_t, new_h, new_w)) + mask = mask.view(b, l_t, n_wh*n_ww) + mask = torch.sum(mask, dim=1) # [b, n_wh*n_ww] + for i in range(win_q.shape[0]): + ### For masked windows + mask_ind_i = mask[i].nonzero(as_tuple=False).view(-1) + # mask out quary in current window + # [b, n_wh*n_ww, n_head, t, w_h*w_w, c_head] + mask_n = len(mask_ind_i) + if mask_n > 0: + win_q_t = win_q[i, mask_ind_i].view(mask_n, self.n_head, t*w_h*w_w, c_head) + win_k_t = win_k[i, mask_ind_i] + win_v_t = win_v[i, mask_ind_i] + # mask out key and value + if T_ind is not None: + # key [n_wh*n_ww, n_head, t, w_h*w_w, c_head] + win_k_t = win_k_t[:, :, T_ind.view(-1)].view(mask_n, self.n_head, -1, c_head) + # value + win_v_t = win_v_t[:, :, T_ind.view(-1)].view(mask_n, self.n_head, -1, c_head) + else: + win_k_t = win_k_t.view(n_wh*n_ww, self.n_head, t*w_h*w_w, c_head) + win_v_t = win_v_t.view(n_wh*n_ww, self.n_head, t*w_h*w_w, c_head) + + att_t = (win_q_t @ win_k_t.transpose(-2, -1)) * (1.0 / math.sqrt(win_q_t.size(-1))) + att_t = F.softmax(att_t, dim=-1) + att_t = self.attn_drop(att_t) + y_t = att_t @ win_v_t + + out[i, mask_ind_i] = y_t.view(-1, self.n_head, t, w_h*w_w, c_head) + + ### For unmasked windows + unmask_ind_i = (mask[i] == 0).nonzero(as_tuple=False).view(-1) + # mask out quary in current window + # [b, n_wh*n_ww, n_head, t, w_h*w_w, c_head] + win_q_s = win_q[i, unmask_ind_i] + win_k_s = win_k[i, unmask_ind_i, :, :, :w_h*w_w] + win_v_s = win_v[i, unmask_ind_i, :, :, :w_h*w_w] + + att_s = (win_q_s @ win_k_s.transpose(-2, -1)) * (1.0 / math.sqrt(win_q_s.size(-1))) + att_s = F.softmax(att_s, dim=-1) + att_s = self.attn_drop(att_s) + y_s = att_s @ win_v_s + out[i, unmask_ind_i] = y_s + + # re-assemble all head outputs side by side + out = out.view(b, n_wh, n_ww, self.n_head, t, w_h, w_w, c_head) + out = out.permute(0, 4, 1, 5, 2, 6, 3, 7).contiguous().view(b, t, new_h, new_w, c) + + + if pad_r > 0 or pad_b > 0: + out = out[:, :, :h, :w, :] + + # output projection + out = self.proj_drop(self.proj(out)) + return out + + +class TemporalSparseTransformer(nn.Module): + def __init__(self, dim, n_head, window_size, pool_size, + norm_layer=nn.LayerNorm, t2t_params=None): + super().__init__() + self.window_size = window_size + self.attention = SparseWindowAttention(dim, n_head, 
window_size, pool_size) + self.norm1 = norm_layer(dim) + self.norm2 = norm_layer(dim) + self.mlp = FusionFeedForward(dim, t2t_params=t2t_params) + + def forward(self, x, fold_x_size, mask=None, T_ind=None): + """ + Args: + x: image tokens, shape [B T H W C] + fold_x_size: fold feature size, shape [60 108] + mask: mask tokens, shape [B T H W 1] + Returns: + out_tokens: shape [B T H W C] + """ + B, T, H, W, C = x.shape # 20 36 + + shortcut = x + x = self.norm1(x) + att_x = self.attention(x, mask, T_ind) + + # FFN + x = shortcut + att_x + y = self.norm2(x) + x = x + self.mlp(y.view(B, T * H * W, C), fold_x_size).view(B, T, H, W, C) + + return x + + +class TemporalSparseTransformerBlock(nn.Module): + def __init__(self, dim, n_head, window_size, pool_size, depths, t2t_params=None): + super().__init__() + blocks = [] + for i in range(depths): + blocks.append( + TemporalSparseTransformer(dim, n_head, window_size, pool_size, t2t_params=t2t_params) + ) + self.transformer = nn.Sequential(*blocks) + self.depths = depths + + def forward(self, x, fold_x_size, l_mask=None, t_dilation=2): + """ + Args: + x: image tokens, shape [B T H W C] + fold_x_size: fold feature size, shape [60 108] + l_mask: local mask tokens, shape [B T H W 1] + Returns: + out_tokens: shape [B T H W C] + """ + assert self.depths % t_dilation == 0, 'wrong t_dilation input.' + T = x.size(1) + T_ind = [torch.arange(i, T, t_dilation) for i in range(t_dilation)] * (self.depths // t_dilation) + + for i in range(0, self.depths): + x = self.transformer[i](x, fold_x_size, l_mask, T_ind[i]) + + return x diff --git a/propainter/model/modules/spectral_norm.py b/propainter/model/modules/spectral_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..f38c34e98c03caa28ce0b15a4083215fb7d8e9af --- /dev/null +++ b/propainter/model/modules/spectral_norm.py @@ -0,0 +1,288 @@ +""" +Spectral Normalization from https://arxiv.org/abs/1802.05957 +""" +import torch +from torch.nn.functional import normalize + + +class SpectralNorm(object): + # Invariant before and after each forward call: + # u = normalize(W @ v) + # NB: At initialization, this invariant is not enforced + + _version = 1 + + # At version 1: + # made `W` not a buffer, + # added `v` as a buffer, and + # made eval mode use `W = u @ W_orig @ v` rather than the stored `W`. + + def __init__(self, name='weight', n_power_iterations=1, dim=0, eps=1e-12): + self.name = name + self.dim = dim + if n_power_iterations <= 0: + raise ValueError( + 'Expected n_power_iterations to be positive, but ' + 'got n_power_iterations={}'.format(n_power_iterations)) + self.n_power_iterations = n_power_iterations + self.eps = eps + + def reshape_weight_to_matrix(self, weight): + weight_mat = weight + if self.dim != 0: + # permute dim to front + weight_mat = weight_mat.permute( + self.dim, + *[d for d in range(weight_mat.dim()) if d != self.dim]) + height = weight_mat.size(0) + return weight_mat.reshape(height, -1) + + def compute_weight(self, module, do_power_iteration): + # NB: If `do_power_iteration` is set, the `u` and `v` vectors are + # updated in power iteration **in-place**. This is very important + # because in `DataParallel` forward, the vectors (being buffers) are + # broadcast from the parallelized module to each module replica, + # which is a new module object created on the fly. And each replica + # runs its own spectral norm power iteration. So simply assigning + # the updated vectors to the module this function runs on will cause + # the update to be lost forever. 
And the next time the parallelized + # module is replicated, the same randomly initialized vectors are + # broadcast and used! + # + # Therefore, to make the change propagate back, we rely on two + # important behaviors (also enforced via tests): + # 1. `DataParallel` doesn't clone storage if the broadcast tensor + # is already on correct device; and it makes sure that the + # parallelized module is already on `device[0]`. + # 2. If the out tensor in `out=` kwarg has correct shape, it will + # just fill in the values. + # Therefore, since the same power iteration is performed on all + # devices, simply updating the tensors in-place will make sure that + # the module replica on `device[0]` will update the _u vector on the + # parallized module (by shared storage). + # + # However, after we update `u` and `v` in-place, we need to **clone** + # them before using them to normalize the weight. This is to support + # backproping through two forward passes, e.g., the common pattern in + # GAN training: loss = D(real) - D(fake). Otherwise, engine will + # complain that variables needed to do backward for the first forward + # (i.e., the `u` and `v` vectors) are changed in the second forward. + weight = getattr(module, self.name + '_orig') + u = getattr(module, self.name + '_u') + v = getattr(module, self.name + '_v') + weight_mat = self.reshape_weight_to_matrix(weight) + + if do_power_iteration: + with torch.no_grad(): + for _ in range(self.n_power_iterations): + # Spectral norm of weight equals to `u^T W v`, where `u` and `v` + # are the first left and right singular vectors. + # This power iteration produces approximations of `u` and `v`. + v = normalize(torch.mv(weight_mat.t(), u), + dim=0, + eps=self.eps, + out=v) + u = normalize(torch.mv(weight_mat, v), + dim=0, + eps=self.eps, + out=u) + if self.n_power_iterations > 0: + # See above on why we need to clone + u = u.clone() + v = v.clone() + + sigma = torch.dot(u, torch.mv(weight_mat, v)) + weight = weight / sigma + return weight + + def remove(self, module): + with torch.no_grad(): + weight = self.compute_weight(module, do_power_iteration=False) + delattr(module, self.name) + delattr(module, self.name + '_u') + delattr(module, self.name + '_v') + delattr(module, self.name + '_orig') + module.register_parameter(self.name, + torch.nn.Parameter(weight.detach())) + + def __call__(self, module, inputs): + setattr( + module, self.name, + self.compute_weight(module, do_power_iteration=module.training)) + + def _solve_v_and_rescale(self, weight_mat, u, target_sigma): + # Tries to returns a vector `v` s.t. `u = normalize(W @ v)` + # (the invariant at top of this class) and `u @ W @ v = sigma`. + # This uses pinverse in case W^T W is not invertible. 
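# Editor's note: compute_weight above approximates the largest singular value
# by power iteration (v ~ normalize(W^T u), u ~ normalize(W v),
# sigma = u . (W v)) and returns W / sigma. The helper below is only needed
# when migrating version-0 checkpoints: it recovers a `v` consistent with the
# stored `u` and a target sigma; in this copy its call site in the load hook
# further down is commented out. Newer PyTorch releases deprecate
# torch.chain_matmul in favour of torch.linalg.multi_dot, so this line may
# emit a deprecation warning (editor's observation, version not pinned here).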
+ v = torch.chain_matmul(weight_mat.t().mm(weight_mat).pinverse(), + weight_mat.t(), u.unsqueeze(1)).squeeze(1) + return v.mul_(target_sigma / torch.dot(u, torch.mv(weight_mat, v))) + + @staticmethod + def apply(module, name, n_power_iterations, dim, eps): + for k, hook in module._forward_pre_hooks.items(): + if isinstance(hook, SpectralNorm) and hook.name == name: + raise RuntimeError( + "Cannot register two spectral_norm hooks on " + "the same parameter {}".format(name)) + + fn = SpectralNorm(name, n_power_iterations, dim, eps) + weight = module._parameters[name] + + with torch.no_grad(): + weight_mat = fn.reshape_weight_to_matrix(weight) + + h, w = weight_mat.size() + # randomly initialize `u` and `v` + u = normalize(weight.new_empty(h).normal_(0, 1), dim=0, eps=fn.eps) + v = normalize(weight.new_empty(w).normal_(0, 1), dim=0, eps=fn.eps) + + delattr(module, fn.name) + module.register_parameter(fn.name + "_orig", weight) + # We still need to assign weight back as fn.name because all sorts of + # things may assume that it exists, e.g., when initializing weights. + # However, we can't directly assign as it could be an nn.Parameter and + # gets added as a parameter. Instead, we register weight.data as a plain + # attribute. + setattr(module, fn.name, weight.data) + module.register_buffer(fn.name + "_u", u) + module.register_buffer(fn.name + "_v", v) + + module.register_forward_pre_hook(fn) + + module._register_state_dict_hook(SpectralNormStateDictHook(fn)) + module._register_load_state_dict_pre_hook( + SpectralNormLoadStateDictPreHook(fn)) + return fn + + +# This is a top level class because Py2 pickle doesn't like inner class nor an +# instancemethod. +class SpectralNormLoadStateDictPreHook(object): + # See docstring of SpectralNorm._version on the changes to spectral_norm. + def __init__(self, fn): + self.fn = fn + + # For state_dict with version None, (assuming that it has gone through at + # least one training forward), we have + # + # u = normalize(W_orig @ v) + # W = W_orig / sigma, where sigma = u @ W_orig @ v + # + # To compute `v`, we solve `W_orig @ x = u`, and let + # v = x / (u @ W_orig @ x) * (W / W_orig). + def __call__(self, state_dict, prefix, local_metadata, strict, + missing_keys, unexpected_keys, error_msgs): + fn = self.fn + version = local_metadata.get('spectral_norm', + {}).get(fn.name + '.version', None) + if version is None or version < 1: + with torch.no_grad(): + weight_orig = state_dict[prefix + fn.name + '_orig'] + # weight = state_dict.pop(prefix + fn.name) + # sigma = (weight_orig / weight).mean() + weight_mat = fn.reshape_weight_to_matrix(weight_orig) + u = state_dict[prefix + fn.name + '_u'] + # v = fn._solve_v_and_rescale(weight_mat, u, sigma) + # state_dict[prefix + fn.name + '_v'] = v + + +# This is a top level class because Py2 pickle doesn't like inner class nor an +# instancemethod. +class SpectralNormStateDictHook(object): + # See docstring of SpectralNorm._version on the changes to spectral_norm. 
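# Editor's note: this hook runs on state_dict() and records
# '<weight name>.version' under local_metadata['spectral_norm'], so the
# load-time pre-hook above can tell old, version-less checkpoints apart from
# ones saved with the current reparameterisation.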
+ def __init__(self, fn): + self.fn = fn + + def __call__(self, module, state_dict, prefix, local_metadata): + if 'spectral_norm' not in local_metadata: + local_metadata['spectral_norm'] = {} + key = self.fn.name + '.version' + if key in local_metadata['spectral_norm']: + raise RuntimeError( + "Unexpected key in metadata['spectral_norm']: {}".format(key)) + local_metadata['spectral_norm'][key] = self.fn._version + + +def spectral_norm(module, + name='weight', + n_power_iterations=1, + eps=1e-12, + dim=None): + r"""Applies spectral normalization to a parameter in the given module. + + .. math:: + \mathbf{W}_{SN} = \dfrac{\mathbf{W}}{\sigma(\mathbf{W})}, + \sigma(\mathbf{W}) = \max_{\mathbf{h}: \mathbf{h} \ne 0} \dfrac{\|\mathbf{W} \mathbf{h}\|_2}{\|\mathbf{h}\|_2} + + Spectral normalization stabilizes the training of discriminators (critics) + in Generative Adversarial Networks (GANs) by rescaling the weight tensor + with spectral norm :math:`\sigma` of the weight matrix calculated using + power iteration method. If the dimension of the weight tensor is greater + than 2, it is reshaped to 2D in power iteration method to get spectral + norm. This is implemented via a hook that calculates spectral norm and + rescales weight before every :meth:`~Module.forward` call. + + See `Spectral Normalization for Generative Adversarial Networks`_ . + + .. _`Spectral Normalization for Generative Adversarial Networks`: https://arxiv.org/abs/1802.05957 + + Args: + module (nn.Module): containing module + name (str, optional): name of weight parameter + n_power_iterations (int, optional): number of power iterations to + calculate spectral norm + eps (float, optional): epsilon for numerical stability in + calculating norms + dim (int, optional): dimension corresponding to number of outputs, + the default is ``0``, except for modules that are instances of + ConvTranspose{1,2,3}d, when it is ``1`` + + Returns: + The original module with the spectral norm hook + + Example:: + + >>> m = spectral_norm(nn.Linear(20, 40)) + >>> m + Linear(in_features=20, out_features=40, bias=True) + >>> m.weight_u.size() + torch.Size([40]) + + """ + if dim is None: + if isinstance(module, + (torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, + torch.nn.ConvTranspose3d)): + dim = 1 + else: + dim = 0 + SpectralNorm.apply(module, name, n_power_iterations, dim, eps) + return module + + +def remove_spectral_norm(module, name='weight'): + r"""Removes the spectral normalization reparameterization from a module. 
+ + Args: + module (Module): containing module + name (str, optional): name of weight parameter + + Example: + >>> m = spectral_norm(nn.Linear(40, 10)) + >>> remove_spectral_norm(m) + """ + for k, hook in module._forward_pre_hooks.items(): + if isinstance(hook, SpectralNorm) and hook.name == name: + hook.remove(module) + del module._forward_pre_hooks[k] + return module + + raise ValueError("spectral_norm of '{}' not found in {}".format( + name, module)) + + +def use_spectral_norm(module, use_sn=False): + if use_sn: + return spectral_norm(module) + return module \ No newline at end of file diff --git a/propainter/model/propainter.py b/propainter/model/propainter.py new file mode 100644 index 0000000000000000000000000000000000000000..71ec0c551b2eaa114aa1e74cd19bcf1542fb3b04 --- /dev/null +++ b/propainter/model/propainter.py @@ -0,0 +1,553 @@ +''' Towards An End-to-End Framework for Video Inpainting +''' + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision + +from einops import rearrange + +try: + from model.modules.base_module import BaseNetwork + from model.modules.sparse_transformer import TemporalSparseTransformerBlock, SoftSplit, SoftComp + from model.modules.spectral_norm import spectral_norm as _spectral_norm + from model.modules.flow_loss_utils import flow_warp + from model.modules.deformconv import ModulatedDeformConv2d + + from .misc import constant_init +except: + from propainter.model.modules.base_module import BaseNetwork + from propainter.model.modules.sparse_transformer import TemporalSparseTransformerBlock, SoftSplit, SoftComp + from propainter.model.modules.spectral_norm import spectral_norm as _spectral_norm + from propainter.model.modules.flow_loss_utils import flow_warp + from propainter.model.modules.deformconv import ModulatedDeformConv2d + + from propainter.model.misc import constant_init + + +def length_sq(x): + return torch.sum(torch.square(x), dim=1, keepdim=True) + +def fbConsistencyCheck(flow_fw, flow_bw, alpha1=0.01, alpha2=0.5): #debug + flow_bw_warped = flow_warp(flow_bw, flow_fw.permute(0, 2, 3, 1)) # wb(wf(x)) + flow_diff_fw = flow_fw + flow_bw_warped # wf + wb(wf(x)) + + mag_sq_fw = length_sq(flow_fw) + length_sq(flow_bw_warped) # |wf| + |wb(wf(x))| + occ_thresh_fw = alpha1 * mag_sq_fw + alpha2 + + # fb_valid_fw = (length_sq(flow_diff_fw) < occ_thresh_fw).float() + fb_valid_fw = (length_sq(flow_diff_fw) < occ_thresh_fw).to(flow_fw) + return fb_valid_fw + + +class DeformableAlignment(ModulatedDeformConv2d): + """Second-order deformable alignment module.""" + def __init__(self, *args, **kwargs): + # self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 10) + self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 3) + + super(DeformableAlignment, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Sequential( + nn.Conv2d(2*self.out_channels + 2 + 1 + 2, self.out_channels, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(self.out_channels, 27 * self.deform_groups, 3, 1, 1), + ) + self.init_offset() + + def init_offset(self): + constant_init(self.conv_offset[-1], val=0, bias=0) + + def forward(self, x, cond_feat, flow): + out = self.conv_offset(cond_feat) + o1, o2, mask = torch.chunk(out, 3, dim=1) + + # offset + offset = self.max_residue_magnitude * 
torch.tanh(torch.cat((o1, o2), dim=1)) + offset = offset + flow.flip(1).repeat(1, offset.size(1) // 2, 1, 1) + + # mask + mask = torch.sigmoid(mask) + + return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, + self.stride, self.padding, + self.dilation, mask) + + +class BidirectionalPropagation(nn.Module): + def __init__(self, channel, learnable=True): + super(BidirectionalPropagation, self).__init__() + self.deform_align = nn.ModuleDict() + self.backbone = nn.ModuleDict() + self.channel = channel + self.prop_list = ['backward_1', 'forward_1'] + self.learnable = learnable + + if self.learnable: + for i, module in enumerate(self.prop_list): + self.deform_align[module] = DeformableAlignment( + channel, channel, 3, padding=1, deform_groups=16) + + self.backbone[module] = nn.Sequential( + nn.Conv2d(2*channel+2, channel, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(channel, channel, 3, 1, 1), + ) + + self.fuse = nn.Sequential( + nn.Conv2d(2*channel+2, channel, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.2, inplace=True), + nn.Conv2d(channel, channel, 3, 1, 1), + ) + + def binary_mask(self, mask, th=0.1): + mask[mask>th] = 1 + mask[mask<=th] = 0 + # return mask.float() + return mask.to(mask) + + def forward(self, x, flows_forward, flows_backward, mask, interpolation='bilinear', direction='forward'): + """ + x shape : [b, t, c, h, w] + return [b, t, c, h, w] + """ + + # For backward warping + # pred_flows_forward for backward feature propagation + # pred_flows_backward for forward feature propagation + b, t, c, h, w = x.shape + feats, masks = {}, {} + feats['input'] = [x[:, i, :, :, :] for i in range(0, t)] + masks['input'] = [mask[:, i, :, :, :] for i in range(0, t)] + + prop_list = ['backward_1', 'forward_1'] + cache_list = ['input'] + prop_list + + for p_i, module_name in enumerate(prop_list): + feats[module_name] = [] + masks[module_name] = [] + + if 'backward' in module_name: + frame_idx = range(0, t) + frame_idx = frame_idx[::-1] + flow_idx = frame_idx + flows_for_prop = flows_forward + flows_for_check = flows_backward + else: + frame_idx = range(0, t) + flow_idx = range(-1, t - 1) + flows_for_prop = flows_backward + flows_for_check = flows_forward + + len_frames_idx = len(frame_idx) + for i, idx in enumerate(frame_idx): + feat_current = feats[cache_list[p_i]][idx] + mask_current = masks[cache_list[p_i]][idx] + + if i == 0: + feat_prop = feat_current + mask_prop = mask_current + else: + flow_prop = flows_for_prop[:, flow_idx[i], :, :, :] + flow_check = flows_for_check[:, flow_idx[i], :, :, :] + flow_vaild_mask = fbConsistencyCheck(flow_prop, flow_check) + feat_warped = flow_warp(feat_prop, flow_prop.permute(0, 2, 3, 1), interpolation) + feat_warped = torch.clamp(feat_warped, min=-1.0, max=1.0) + + if self.learnable: + cond = torch.cat([feat_current, feat_warped, flow_prop, flow_vaild_mask, mask_current], dim=1) + feat_prop = self.deform_align[module_name](feat_prop, cond, flow_prop) + mask_prop = mask_current + else: + mask_prop_valid = flow_warp(mask_prop, flow_prop.permute(0, 2, 3, 1)) + mask_prop_valid = self.binary_mask(mask_prop_valid) + + union_vaild_mask = self.binary_mask(mask_current*flow_vaild_mask*(1-mask_prop_valid)) + feat_prop = union_vaild_mask * feat_warped + (1-union_vaild_mask) * feat_current + # update mask + mask_prop = self.binary_mask(mask_current*(1-(flow_vaild_mask*(1-mask_prop_valid)))) + + + # refine + if self.learnable: + feat = torch.cat([feat_current, feat_prop, mask_current], dim=1) + feat_prop = feat_prop + 
self.backbone[module_name](feat) + # feat_prop = self.backbone[module_name](feat_prop) + + feats[module_name].append(feat_prop) + masks[module_name].append(mask_prop) + + + # end for + if 'backward' in module_name: + feats[module_name] = feats[module_name][::-1] + masks[module_name] = masks[module_name][::-1] + + outputs_b = torch.stack(feats['backward_1'], dim=1).view(-1, c, h, w) + outputs_f = torch.stack(feats['forward_1'], dim=1).view(-1, c, h, w) + + if self.learnable: + mask_in = mask.view(-1, 2, h, w) + masks_b, masks_f = None, None + outputs = self.fuse(torch.cat([outputs_b, outputs_f, mask_in], dim=1)) + x.view(-1, c, h, w) + else: + if direction == 'forward': + masks_b = torch.stack(masks['backward_1'], dim=1) + masks_f = torch.stack(masks['forward_1'], dim=1) + outputs = outputs_f + else: + masks_b = torch.stack(masks['backward_1'], dim=1) + masks_f = torch.stack(masks['forward_1'], dim=1) + outputs = outputs_b + return outputs_b.view(b, -1, c, h, w), outputs_f.view(b, -1, c, h, w), \ + outputs.view(b, -1, c, h, w), masks_b + + return outputs_b.view(b, -1, c, h, w), outputs_f.view(b, -1, c, h, w), \ + outputs.view(b, -1, c, h, w), masks_f + + +class Encoder(nn.Module): + def __init__(self): + super(Encoder, self).__init__() + self.group = [1, 2, 4, 8, 1] + self.layers = nn.ModuleList([ + nn.Conv2d(5, 64, kernel_size=3, stride=2, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(256, 384, kernel_size=3, stride=1, padding=1, groups=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(640, 512, kernel_size=3, stride=1, padding=1, groups=2), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(768, 384, kernel_size=3, stride=1, padding=1, groups=4), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(640, 256, kernel_size=3, stride=1, padding=1, groups=8), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(512, 128, kernel_size=3, stride=1, padding=1, groups=1), + nn.LeakyReLU(0.2, inplace=True) + ]) + + def forward(self, x): + bt, c, _, _ = x.size() + # h, w = h//4, w//4 + out = x + for i, layer in enumerate(self.layers): + if i == 8: + x0 = out + _, _, h, w = x0.size() + if i > 8 and i % 2 == 0: + g = self.group[(i - 8) // 2] + x = x0.view(bt, g, -1, h, w) + o = out.view(bt, g, -1, h, w) + out = torch.cat([x, o], 2).view(bt, -1, h, w) + out = layer(out) + return out + + +class deconv(nn.Module): + def __init__(self, + input_channel, + output_channel, + kernel_size=3, + padding=0): + super().__init__() + self.conv = nn.Conv2d(input_channel, + output_channel, + kernel_size=kernel_size, + stride=1, + padding=padding) + + def forward(self, x): + x = F.interpolate(x, + scale_factor=2, + mode='bilinear', + align_corners=True) + return self.conv(x) + + +class InpaintGenerator(BaseNetwork): + def __init__(self, init_weights=True, model_path=None): + super(InpaintGenerator, self).__init__() + channel = 128 + hidden = 512 + + # encoder + self.encoder = Encoder() + + # decoder + self.decoder = nn.Sequential( + deconv(channel, 128, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(128, 64, kernel_size=3, stride=1, padding=1), + nn.LeakyReLU(0.2, inplace=True), + deconv(64, 64, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1)) + + # soft 
split and soft composition + kernel_size = (7, 7) + padding = (3, 3) + stride = (3, 3) + t2t_params = { + 'kernel_size': kernel_size, + 'stride': stride, + 'padding': padding + } + self.ss = SoftSplit(channel, hidden, kernel_size, stride, padding) + self.sc = SoftComp(channel, hidden, kernel_size, stride, padding) + self.max_pool = nn.MaxPool2d(kernel_size, stride, padding) + + # feature propagation module + self.img_prop_module = BidirectionalPropagation(3, learnable=False) + self.feat_prop_module = BidirectionalPropagation(128, learnable=True) + + + depths = 8 + num_heads = 4 + window_size = (5, 9) + pool_size = (4, 4) + self.transformers = TemporalSparseTransformerBlock(dim=hidden, + n_head=num_heads, + window_size=window_size, + pool_size=pool_size, + depths=depths, + t2t_params=t2t_params) + if init_weights: + self.init_weights() + + + if model_path is not None: + # print('Pretrained ProPainter has loaded...') + ckpt = torch.load(model_path, map_location='cpu') + self.load_state_dict(ckpt, strict=True) + + # print network parameter number + # self.print_network() + + def img_propagation(self, masked_frames, completed_flows, masks, interpolation='nearest', direction = 'forward'): + _, _, prop_frames, updated_masks = self.img_prop_module(masked_frames, completed_flows[0], completed_flows[1], masks, interpolation, direction) + return prop_frames, updated_masks + + def forward(self, masked_frames, completed_flows, masks_in, masks_updated, num_local_frames, interpolation='bilinear', t_dilation=2): + """ + Args: + masks_in: original mask + masks_updated: updated mask after image propagation + """ + + l_t = num_local_frames + b, t, _, ori_h, ori_w = masked_frames.size() + + # extracting features + enc_feat = self.encoder(torch.cat([masked_frames.view(b * t, 3, ori_h, ori_w), + masks_in.view(b * t, 1, ori_h, ori_w), + masks_updated.view(b * t, 1, ori_h, ori_w)], dim=1)) + _, c, h, w = enc_feat.size() + local_feat = enc_feat.view(b, t, c, h, w)[:, :l_t, ...] + ref_feat = enc_feat.view(b, t, c, h, w)[:, l_t:, ...] 
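# Editor's note: the first l_t frames are the "local" neighbouring frames and
# go through flow-guided feature propagation below; the remaining t - l_t
# "reference" frames skip propagation and only interact through the sparse
# transformer. The encoder above downsamples by 4 (two stride-2 convs), so the
# completed flows and masks are resized by 1/4 to the feature resolution, and
# the flow values themselves are divided by 4 because flow vectors are
# measured in pixels.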
+ fold_feat_size = (h, w) + + ds_flows_f = F.interpolate(completed_flows[0].view(-1, 2, ori_h, ori_w), scale_factor=1/4, mode='bilinear', align_corners=False).view(b, l_t-1, 2, h, w)/4.0 + ds_flows_b = F.interpolate(completed_flows[1].view(-1, 2, ori_h, ori_w), scale_factor=1/4, mode='bilinear', align_corners=False).view(b, l_t-1, 2, h, w)/4.0 + ds_mask_in = F.interpolate(masks_in.reshape(-1, 1, ori_h, ori_w), scale_factor=1/4, mode='nearest').view(b, t, 1, h, w) + ds_mask_in_local = ds_mask_in[:, :l_t] + ds_mask_updated_local = F.interpolate(masks_updated[:,:l_t].reshape(-1, 1, ori_h, ori_w), scale_factor=1/4, mode='nearest').view(b, l_t, 1, h, w) + + + if self.training: + mask_pool_l = self.max_pool(ds_mask_in.view(-1, 1, h, w)) + mask_pool_l = mask_pool_l.view(b, t, 1, mask_pool_l.size(-2), mask_pool_l.size(-1)) + else: + mask_pool_l = self.max_pool(ds_mask_in_local.view(-1, 1, h, w)) + mask_pool_l = mask_pool_l.view(b, l_t, 1, mask_pool_l.size(-2), mask_pool_l.size(-1)) + + + prop_mask_in = torch.cat([ds_mask_in_local, ds_mask_updated_local], dim=2) + _, _, local_feat, _ = self.feat_prop_module(local_feat, ds_flows_f, ds_flows_b, prop_mask_in, interpolation) + enc_feat = torch.cat((local_feat, ref_feat), dim=1) + + trans_feat = self.ss(enc_feat.view(-1, c, h, w), b, fold_feat_size) + mask_pool_l = rearrange(mask_pool_l, 'b t c h w -> b t h w c').contiguous() + trans_feat = self.transformers(trans_feat, fold_feat_size, mask_pool_l, t_dilation=t_dilation) + trans_feat = self.sc(trans_feat, t, fold_feat_size) + trans_feat = trans_feat.view(b, t, -1, h, w) + + enc_feat = enc_feat + trans_feat + + if self.training: + output = self.decoder(enc_feat.view(-1, c, h, w)) + output = torch.tanh(output).view(b, t, 3, ori_h, ori_w) + else: + output = self.decoder(enc_feat[:, :l_t].view(-1, c, h, w)) + output = torch.tanh(output).view(b, l_t, 3, ori_h, ori_w) + + return output + + +# ###################################################################### +# Discriminator for Temporal Patch GAN +# ###################################################################### +class Discriminator(BaseNetwork): + def __init__(self, + in_channels=3, + use_sigmoid=False, + use_spectral_norm=True, + init_weights=True): + super(Discriminator, self).__init__() + self.use_sigmoid = use_sigmoid + nf = 32 + + self.conv = nn.Sequential( + spectral_norm( + nn.Conv3d(in_channels=in_channels, + out_channels=nf * 1, + kernel_size=(3, 5, 5), + stride=(1, 2, 2), + padding=1, + bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(64, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm( + nn.Conv3d(nf * 1, + nf * 2, + kernel_size=(3, 5, 5), + stride=(1, 2, 2), + padding=(1, 2, 2), + bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(128, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm( + nn.Conv3d(nf * 2, + nf * 4, + kernel_size=(3, 5, 5), + stride=(1, 2, 2), + padding=(1, 2, 2), + bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(256, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm( + nn.Conv3d(nf * 4, + nf * 4, + kernel_size=(3, 5, 5), + stride=(1, 2, 2), + padding=(1, 2, 2), + bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(256, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm( + nn.Conv3d(nf * 4, + nf * 4, + kernel_size=(3, 5, 5), + stride=(1, 2, 2), + padding=(1, 2, 2), + bias=not use_spectral_norm), use_spectral_norm), + # 
nn.InstanceNorm2d(256, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv3d(nf * 4, + nf * 4, + kernel_size=(3, 5, 5), + stride=(1, 2, 2), + padding=(1, 2, 2))) + + if init_weights: + self.init_weights() + + def forward(self, xs): + # T, C, H, W = xs.shape (old) + # B, T, C, H, W (new) + xs_t = torch.transpose(xs, 1, 2) + feat = self.conv(xs_t) + if self.use_sigmoid: + feat = torch.sigmoid(feat) + out = torch.transpose(feat, 1, 2) # B, T, C, H, W + return out + + +class Discriminator_2D(BaseNetwork): + def __init__(self, + in_channels=3, + use_sigmoid=False, + use_spectral_norm=True, + init_weights=True): + super(Discriminator_2D, self).__init__() + self.use_sigmoid = use_sigmoid + nf = 32 + + self.conv = nn.Sequential( + spectral_norm( + nn.Conv3d(in_channels=in_channels, + out_channels=nf * 1, + kernel_size=(1, 5, 5), + stride=(1, 2, 2), + padding=(0, 2, 2), + bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(64, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm( + nn.Conv3d(nf * 1, + nf * 2, + kernel_size=(1, 5, 5), + stride=(1, 2, 2), + padding=(0, 2, 2), + bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(128, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm( + nn.Conv3d(nf * 2, + nf * 4, + kernel_size=(1, 5, 5), + stride=(1, 2, 2), + padding=(0, 2, 2), + bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(256, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm( + nn.Conv3d(nf * 4, + nf * 4, + kernel_size=(1, 5, 5), + stride=(1, 2, 2), + padding=(0, 2, 2), + bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(256, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + spectral_norm( + nn.Conv3d(nf * 4, + nf * 4, + kernel_size=(1, 5, 5), + stride=(1, 2, 2), + padding=(0, 2, 2), + bias=not use_spectral_norm), use_spectral_norm), + # nn.InstanceNorm2d(256, track_running_stats=False), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv3d(nf * 4, + nf * 4, + kernel_size=(1, 5, 5), + stride=(1, 2, 2), + padding=(0, 2, 2))) + + if init_weights: + self.init_weights() + + def forward(self, xs): + # T, C, H, W = xs.shape (old) + # B, T, C, H, W (new) + xs_t = torch.transpose(xs, 1, 2) + feat = self.conv(xs_t) + if self.use_sigmoid: + feat = torch.sigmoid(feat) + out = torch.transpose(feat, 1, 2) # B, T, C, H, W + return out + +def spectral_norm(module, mode=True): + if mode: + return _spectral_norm(module) + return module diff --git a/propainter/model/recurrent_flow_completion.py b/propainter/model/recurrent_flow_completion.py new file mode 100644 index 0000000000000000000000000000000000000000..c1dcc422e42f4e86e94f6eba2ac962a26f4b1084 --- /dev/null +++ b/propainter/model/recurrent_flow_completion.py @@ -0,0 +1,352 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision + +try: + from model.modules.deformconv import ModulatedDeformConv2d + from .misc import constant_init +except: + from propainter.model.modules.deformconv import ModulatedDeformConv2d + from propainter.model.misc import constant_init + + +class SecondOrderDeformableAlignment(ModulatedDeformConv2d): + """Second-order deformable alignment module.""" + def __init__(self, *args, **kwargs): + self.max_residue_magnitude = kwargs.pop('max_residue_magnitude', 5) + + super(SecondOrderDeformableAlignment, self).__init__(*args, **kwargs) + + self.conv_offset = nn.Sequential( + nn.Conv2d(3 * self.out_channels, 
self.out_channels, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(self.out_channels, self.out_channels, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(self.out_channels, 27 * self.deform_groups, 3, 1, 1), + ) + self.init_offset() + + def init_offset(self): + constant_init(self.conv_offset[-1], val=0, bias=0) + + def forward(self, x, extra_feat): + out = self.conv_offset(extra_feat) + o1, o2, mask = torch.chunk(out, 3, dim=1) + + # offset + offset = self.max_residue_magnitude * torch.tanh(torch.cat((o1, o2), dim=1)) + offset_1, offset_2 = torch.chunk(offset, 2, dim=1) + offset = torch.cat([offset_1, offset_2], dim=1) + + # mask + mask = torch.sigmoid(mask) + + return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, + self.stride, self.padding, + self.dilation, mask) + +class BidirectionalPropagation(nn.Module): + def __init__(self, channel): + super(BidirectionalPropagation, self).__init__() + modules = ['backward_', 'forward_'] + self.deform_align = nn.ModuleDict() + self.backbone = nn.ModuleDict() + self.channel = channel + + for i, module in enumerate(modules): + self.deform_align[module] = SecondOrderDeformableAlignment( + 2 * channel, channel, 3, padding=1, deform_groups=16) + + self.backbone[module] = nn.Sequential( + nn.Conv2d((2 + i) * channel, channel, 3, 1, 1), + nn.LeakyReLU(negative_slope=0.1, inplace=True), + nn.Conv2d(channel, channel, 3, 1, 1), + ) + + self.fusion = nn.Conv2d(2 * channel, channel, 1, 1, 0) + + def forward(self, x): + """ + x shape : [b, t, c, h, w] + return [b, t, c, h, w] + """ + b, t, c, h, w = x.shape + feats = {} + feats['spatial'] = [x[:, i, :, :, :] for i in range(0, t)] + + for module_name in ['backward_', 'forward_']: + + feats[module_name] = [] + + frame_idx = range(0, t) + mapping_idx = list(range(0, len(feats['spatial']))) + mapping_idx += mapping_idx[::-1] + + if 'backward' in module_name: + frame_idx = frame_idx[::-1] + + feat_prop = x.new_zeros(b, self.channel, h, w) + for i, idx in enumerate(frame_idx): + feat_current = feats['spatial'][mapping_idx[idx]] + if i > 0: + cond_n1 = feat_prop + + # initialize second-order features + feat_n2 = torch.zeros_like(feat_prop) + cond_n2 = torch.zeros_like(cond_n1) + if i > 1: # second-order features + feat_n2 = feats[module_name][-2] + cond_n2 = feat_n2 + + cond = torch.cat([cond_n1, feat_current, cond_n2], dim=1) # condition information, cond(flow warped 1st/2nd feature) + feat_prop = torch.cat([feat_prop, feat_n2], dim=1) # two order feat_prop -1 & -2 + feat_prop = self.deform_align[module_name](feat_prop, cond) + + # fuse current features + feat = [feat_current] + \ + [feats[k][idx] for k in feats if k not in ['spatial', module_name]] \ + + [feat_prop] + + feat = torch.cat(feat, dim=1) + # embed current features + feat_prop = feat_prop + self.backbone[module_name](feat) + + feats[module_name].append(feat_prop) + + # end for + if 'backward' in module_name: + feats[module_name] = feats[module_name][::-1] + + outputs = [] + for i in range(0, t): + align_feats = [feats[k].pop(0) for k in feats if k != 'spatial'] + align_feats = torch.cat(align_feats, dim=1) + outputs.append(self.fusion(align_feats)) + + return torch.stack(outputs, dim=1) + x + + +class deconv(nn.Module): + def __init__(self, + input_channel, + output_channel, + kernel_size=3, + padding=0): + super().__init__() + self.conv = nn.Conv2d(input_channel, + output_channel, + 
kernel_size=kernel_size, + stride=1, + padding=padding) + + def forward(self, x): + x = F.interpolate(x, + scale_factor=2, + mode='bilinear', + align_corners=True) + return self.conv(x) + + +class P3DBlock(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size, stride, padding, use_residual=0, bias=True): + super().__init__() + self.conv1 = nn.Sequential( + nn.Conv3d(in_channels, out_channels, kernel_size=(1, kernel_size, kernel_size), + stride=(1, stride, stride), padding=(0, padding, padding), bias=bias), + nn.LeakyReLU(0.2, inplace=True) + ) + self.conv2 = nn.Sequential( + nn.Conv3d(out_channels, out_channels, kernel_size=(3, 1, 1), stride=(1, 1, 1), + padding=(2, 0, 0), dilation=(2, 1, 1), bias=bias) + ) + self.use_residual = use_residual + + def forward(self, feats): + feat1 = self.conv1(feats) + feat2 = self.conv2(feat1) + if self.use_residual: + output = feats + feat2 + else: + output = feat2 + return output + + +class EdgeDetection(nn.Module): + def __init__(self, in_ch=2, out_ch=1, mid_ch=16): + super().__init__() + self.projection = nn.Sequential( + nn.Conv2d(in_ch, mid_ch, 3, 1, 1), + nn.LeakyReLU(0.2, inplace=True) + ) + + self.mid_layer_1 = nn.Sequential( + nn.Conv2d(mid_ch, mid_ch, 3, 1, 1), + nn.LeakyReLU(0.2, inplace=True) + ) + + self.mid_layer_2 = nn.Sequential( + nn.Conv2d(mid_ch, mid_ch, 3, 1, 1) + ) + + self.l_relu = nn.LeakyReLU(0.01, inplace=True) + + self.out_layer = nn.Conv2d(mid_ch, out_ch, 1, 1, 0) + + def forward(self, flow): + flow = self.projection(flow) + edge = self.mid_layer_1(flow) + edge = self.mid_layer_2(edge) + edge = self.l_relu(flow + edge) + edge = self.out_layer(edge) + edge = torch.sigmoid(edge) + return edge + + +class RecurrentFlowCompleteNet(nn.Module): + def __init__(self, model_path=None): + super().__init__() + self.downsample = nn.Sequential( + nn.Conv3d(3, 32, kernel_size=(1, 5, 5), stride=(1, 2, 2), + padding=(0, 2, 2), padding_mode='replicate'), + nn.LeakyReLU(0.2, inplace=True) + ) + + self.encoder1 = nn.Sequential( + P3DBlock(32, 32, 3, 1, 1), + nn.LeakyReLU(0.2, inplace=True), + P3DBlock(32, 64, 3, 2, 1), + nn.LeakyReLU(0.2, inplace=True) + ) # 4x + + self.encoder2 = nn.Sequential( + P3DBlock(64, 64, 3, 1, 1), + nn.LeakyReLU(0.2, inplace=True), + P3DBlock(64, 128, 3, 2, 1), + nn.LeakyReLU(0.2, inplace=True) + ) # 8x + + self.mid_dilation = nn.Sequential( + nn.Conv3d(128, 128, (1, 3, 3), (1, 1, 1), padding=(0, 3, 3), dilation=(1, 3, 3)), # p = d*(k-1)/2 + nn.LeakyReLU(0.2, inplace=True), + nn.Conv3d(128, 128, (1, 3, 3), (1, 1, 1), padding=(0, 2, 2), dilation=(1, 2, 2)), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv3d(128, 128, (1, 3, 3), (1, 1, 1), padding=(0, 1, 1), dilation=(1, 1, 1)), + nn.LeakyReLU(0.2, inplace=True) + ) + + # feature propagation module + self.feat_prop_module = BidirectionalPropagation(128) + + self.decoder2 = nn.Sequential( + nn.Conv2d(128, 128, 3, 1, 1), + nn.LeakyReLU(0.2, inplace=True), + deconv(128, 64, 3, 1), + nn.LeakyReLU(0.2, inplace=True) + ) # 4x + + self.decoder1 = nn.Sequential( + nn.Conv2d(64, 64, 3, 1, 1), + nn.LeakyReLU(0.2, inplace=True), + deconv(64, 32, 3, 1), + nn.LeakyReLU(0.2, inplace=True) + ) # 2x + + self.upsample = nn.Sequential( + nn.Conv2d(32, 32, 3, padding=1), + nn.LeakyReLU(0.2, inplace=True), + deconv(32, 2, 3, 1) + ) + + # edge loss + self.edgeDetector = EdgeDetection(in_ch=2, out_ch=1, mid_ch=16) + + # Need to initial the weights of MSDeformAttn specifically + for m in self.modules(): + if isinstance(m, SecondOrderDeformableAlignment): + m.init_offset() + + if 
model_path is not None: + # print('Pretrained flow completion model has loaded...') + ckpt = torch.load(model_path, map_location='cpu') + self.load_state_dict(ckpt, strict=True) + + + def forward(self, masked_flows, masks): + # masked_flows: b t-1 2 h w + # masks: b t-1 2 h w + b, t, _, h, w = masked_flows.size() + masked_flows = masked_flows.permute(0,2,1,3,4) + masks = masks.permute(0,2,1,3,4) + + inputs = torch.cat((masked_flows, masks), dim=1) + + x = self.downsample(inputs) + + feat_e1 = self.encoder1(x) + feat_e2 = self.encoder2(feat_e1) # b c t h w + feat_mid = self.mid_dilation(feat_e2) # b c t h w + feat_mid = feat_mid.permute(0,2,1,3,4) # b t c h w + + feat_prop = self.feat_prop_module(feat_mid) + feat_prop = feat_prop.view(-1, 128, h//8, w//8) # b*t c h w + + _, c, _, h_f, w_f = feat_e1.shape + feat_e1 = feat_e1.permute(0,2,1,3,4).contiguous().view(-1, c, h_f, w_f) # b*t c h w + feat_d2 = self.decoder2(feat_prop) + feat_e1 + + _, c, _, h_f, w_f = x.shape + x = x.permute(0,2,1,3,4).contiguous().view(-1, c, h_f, w_f) # b*t c h w + + feat_d1 = self.decoder1(feat_d2) + + flow = self.upsample(feat_d1) + if self.training: + edge = self.edgeDetector(flow) + edge = edge.view(b, t, 1, h, w) + else: + edge = None + + flow = flow.view(b, t, 2, h, w) + + return flow, edge + + + def forward_bidirect_flow(self, masked_flows_bi, masks): + """ + Args: + masked_flows_bi: [masked_flows_f, masked_flows_b] | (b t-1 2 h w), (b t-1 2 h w) + masks: b t 1 h w + """ + masks_forward = masks[:, :-1, ...].contiguous() + masks_backward = masks[:, 1:, ...].contiguous() + + # mask flow + masked_flows_forward = masked_flows_bi[0] * (1-masks_forward) + masked_flows_backward = masked_flows_bi[1] * (1-masks_backward) + + # -- completion -- + # forward + pred_flows_forward, pred_edges_forward = self.forward(masked_flows_forward, masks_forward) + + # backward + masked_flows_backward = torch.flip(masked_flows_backward, dims=[1]) + masks_backward = torch.flip(masks_backward, dims=[1]) + pred_flows_backward, pred_edges_backward = self.forward(masked_flows_backward, masks_backward) + pred_flows_backward = torch.flip(pred_flows_backward, dims=[1]) + if self.training: + pred_edges_backward = torch.flip(pred_edges_backward, dims=[1]) + + return [pred_flows_forward, pred_flows_backward], [pred_edges_forward, pred_edges_backward] + + + def combine_flow(self, masked_flows_bi, pred_flows_bi, masks): + masks_forward = masks[:, :-1, ...].contiguous() + masks_backward = masks[:, 1:, ...].contiguous() + + pred_flows_forward = pred_flows_bi[0] * masks_forward + masked_flows_bi[0] * (1-masks_forward) + pred_flows_backward = pred_flows_bi[1] * masks_backward + masked_flows_bi[1] * (1-masks_backward) + + return pred_flows_forward, pred_flows_backward diff --git a/propainter/model/vgg_arch.py b/propainter/model/vgg_arch.py new file mode 100644 index 0000000000000000000000000000000000000000..43fc2ff8bc1c73313d632c6ab326372d389a4772 --- /dev/null +++ b/propainter/model/vgg_arch.py @@ -0,0 +1,157 @@ +import os +import torch +from collections import OrderedDict +from torch import nn as nn +from torchvision.models import vgg as vgg + +VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth' +NAMES = { + 'vgg11': [ + 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', + 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', + 'pool5' + ], + 'vgg13': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 
'relu2_1', 'conv2_2', 'relu2_2', 'pool2', + 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', + 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5' + ], + 'vgg16': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', + 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', + 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', + 'pool5' + ], + 'vgg19': [ + 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2', + 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1', + 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1', + 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5' + ] +} + + +def insert_bn(names): + """Insert bn layer after each conv. + + Args: + names (list): The list of layer names. + + Returns: + list: The list of layer names with bn layers. + """ + names_bn = [] + for name in names: + names_bn.append(name) + if 'conv' in name: + position = name.replace('conv', '') + names_bn.append('bn' + position) + return names_bn + +class VGGFeatureExtractor(nn.Module): + """VGG network for feature extraction. + + In this implementation, we allow users to choose whether use normalization + in the input feature and the type of vgg network. Note that the pretrained + path must fit the vgg type. + + Args: + layer_name_list (list[str]): Forward function returns the corresponding + features according to the layer_name_list. + Example: {'relu1_1', 'relu2_1', 'relu3_1'}. + vgg_type (str): Set the type of vgg network. Default: 'vgg19'. + use_input_norm (bool): If True, normalize the input image. Importantly, + the input feature must in the range [0, 1]. Default: True. + range_norm (bool): If True, norm images with range [-1, 1] to [0, 1]. + Default: False. + requires_grad (bool): If true, the parameters of VGG network will be + optimized. Default: False. + remove_pooling (bool): If true, the max pooling operations in VGG net + will be removed. Default: False. + pooling_stride (int): The stride of max pooling operation. Default: 2. 
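+
+    Example:
+        Minimal usage sketch (illustrative input; torchvision VGG19 weights are
+        downloaded on first use if no local copy exists at VGG_PRETRAIN_PATH):
+
+        >>> import torch
+        >>> extractor = VGGFeatureExtractor(['relu1_1', 'relu2_1'])
+        >>> feats = extractor(torch.rand(1, 3, 224, 224))  # RGB batch in [0, 1]
+        >>> sorted(feats.keys())
+        ['relu1_1', 'relu2_1']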
+ """ + + def __init__(self, + layer_name_list, + vgg_type='vgg19', + use_input_norm=True, + range_norm=False, + requires_grad=False, + remove_pooling=False, + pooling_stride=2): + super(VGGFeatureExtractor, self).__init__() + + self.layer_name_list = layer_name_list + self.use_input_norm = use_input_norm + self.range_norm = range_norm + + self.names = NAMES[vgg_type.replace('_bn', '')] + if 'bn' in vgg_type: + self.names = insert_bn(self.names) + + # only borrow layers that will be used to avoid unused params + max_idx = 0 + for v in layer_name_list: + idx = self.names.index(v) + if idx > max_idx: + max_idx = idx + + if os.path.exists(VGG_PRETRAIN_PATH): + vgg_net = getattr(vgg, vgg_type)(pretrained=False) + state_dict = torch.load(VGG_PRETRAIN_PATH, map_location=lambda storage, loc: storage) + vgg_net.load_state_dict(state_dict) + else: + vgg_net = getattr(vgg, vgg_type)(pretrained=True) + + features = vgg_net.features[:max_idx + 1] + + modified_net = OrderedDict() + for k, v in zip(self.names, features): + if 'pool' in k: + # if remove_pooling is true, pooling operation will be removed + if remove_pooling: + continue + else: + # in some cases, we may want to change the default stride + modified_net[k] = nn.MaxPool2d(kernel_size=2, stride=pooling_stride) + else: + modified_net[k] = v + + self.vgg_net = nn.Sequential(modified_net) + + if not requires_grad: + self.vgg_net.eval() + for param in self.parameters(): + param.requires_grad = False + else: + self.vgg_net.train() + for param in self.parameters(): + param.requires_grad = True + + if self.use_input_norm: + # the mean is for image with range [0, 1] + self.register_buffer('mean', torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)) + # the std is for image with range [0, 1] + self.register_buffer('std', torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)) + + def forward(self, x): + """Forward function. + + Args: + x (Tensor): Input tensor with shape (n, c, h, w). + + Returns: + Tensor: Forward results. + """ + if self.range_norm: + x = (x + 1) / 2 + if self.use_input_norm: + x = (x - self.mean) / self.std + output = {} + + for key, layer in self.vgg_net._modules.items(): + x = layer(x) + if key in self.layer_name_list: + output[key] = x.clone() + + return output diff --git a/propainter/utils/download_util.py b/propainter/utils/download_util.py new file mode 100644 index 0000000000000000000000000000000000000000..5e8fb1b00522309d0c0931f5396355011fb200e7 --- /dev/null +++ b/propainter/utils/download_util.py @@ -0,0 +1,109 @@ +import math +import os +import requests +from torch.hub import download_url_to_file, get_dir +from tqdm import tqdm +from urllib.parse import urlparse + +def sizeof_fmt(size, suffix='B'): + """Get human readable file size. + + Args: + size (int): File size. + suffix (str): Suffix. Default: 'B'. + + Return: + str: Formated file siz. + """ + for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: + if abs(size) < 1024.0: + return f'{size:3.1f} {unit}{suffix}' + size /= 1024.0 + return f'{size:3.1f} Y{suffix}' + + +def download_file_from_google_drive(file_id, save_path): + """Download files from google drive. + Ref: + https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 + Args: + file_id (str): File id. + save_path (str): Save path. 
+ """ + + session = requests.Session() + URL = 'https://docs.google.com/uc?export=download' + params = {'id': file_id} + + response = session.get(URL, params=params, stream=True) + token = get_confirm_token(response) + if token: + params['confirm'] = token + response = session.get(URL, params=params, stream=True) + + # get file size + response_file_size = session.get(URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) + print(response_file_size) + if 'Content-Range' in response_file_size.headers: + file_size = int(response_file_size.headers['Content-Range'].split('/')[1]) + else: + file_size = None + + save_response_content(response, save_path, file_size) + + +def get_confirm_token(response): + for key, value in response.cookies.items(): + if key.startswith('download_warning'): + return value + return None + + +def save_response_content(response, destination, file_size=None, chunk_size=32768): + if file_size is not None: + pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') + + readable_file_size = sizeof_fmt(file_size) + else: + pbar = None + + with open(destination, 'wb') as f: + downloaded_size = 0 + for chunk in response.iter_content(chunk_size): + downloaded_size += chunk_size + if pbar is not None: + pbar.update(1) + pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} / {readable_file_size}') + if chunk: # filter out keep-alive new chunks + f.write(chunk) + if pbar is not None: + pbar.close() + + +def load_file_from_url(url, model_dir=None, progress=True, file_name=None): + """Load file form http url, will download models if necessary. + Ref:https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py + Args: + url (str): URL to be downloaded. + model_dir (str): The path to save the downloaded model. Should be a full path. If None, use pytorch hub_dir. + Default: None. + progress (bool): Whether to show the download progress. Default: True. + file_name (str): The downloaded file name. If None, use the file name in the url. Default: None. + Returns: + str: The path to the downloaded file. + """ + if model_dir is None: # use the pytorch hub_dir + hub_dir = get_dir() + model_dir = os.path.join(hub_dir, 'checkpoints') + + os.makedirs(model_dir, exist_ok=True) + + parts = urlparse(url) + filename = os.path.basename(parts.path) + if file_name is not None: + filename = file_name + cached_file = os.path.abspath(os.path.join(model_dir, filename)) + if not os.path.exists(cached_file): + print(f'Downloading: "{url}" to {cached_file}\n') + download_url_to_file(url, cached_file, hash_prefix=None, progress=progress) + return cached_file \ No newline at end of file diff --git a/propainter/utils/file_client.py b/propainter/utils/file_client.py new file mode 100644 index 0000000000000000000000000000000000000000..c7578dec8304475b3906d5dcc734a5a58b56f2ad --- /dev/null +++ b/propainter/utils/file_client.py @@ -0,0 +1,166 @@ +from abc import ABCMeta, abstractmethod + + +class BaseStorageBackend(metaclass=ABCMeta): + """Abstract class of storage backends. + + All backends need to implement two apis: ``get()`` and ``get_text()``. + ``get()`` reads the file as a byte stream and ``get_text()`` reads the file + as texts. + """ + + @abstractmethod + def get(self, filepath): + pass + + @abstractmethod + def get_text(self, filepath): + pass + + +class MemcachedBackend(BaseStorageBackend): + """Memcached storage backend. + + Attributes: + server_list_cfg (str): Config file for memcached server list. + client_cfg (str): Config file for memcached client. 
+ sys_path (str | None): Additional path to be appended to `sys.path`. + Default: None. + """ + + def __init__(self, server_list_cfg, client_cfg, sys_path=None): + if sys_path is not None: + import sys + sys.path.append(sys_path) + try: + import mc + except ImportError: + raise ImportError('Please install memcached to enable MemcachedBackend.') + + self.server_list_cfg = server_list_cfg + self.client_cfg = client_cfg + self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, self.client_cfg) + # mc.pyvector servers as a point which points to a memory cache + self._mc_buffer = mc.pyvector() + + def get(self, filepath): + filepath = str(filepath) + import mc + self._client.Get(filepath, self._mc_buffer) + value_buf = mc.ConvertBuffer(self._mc_buffer) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class HardDiskBackend(BaseStorageBackend): + """Raw hard disks storage backend.""" + + def get(self, filepath): + filepath = str(filepath) + with open(filepath, 'rb') as f: + value_buf = f.read() + return value_buf + + def get_text(self, filepath): + filepath = str(filepath) + with open(filepath, 'r') as f: + value_buf = f.read() + return value_buf + + +class LmdbBackend(BaseStorageBackend): + """Lmdb storage backend. + + Args: + db_paths (str | list[str]): Lmdb database paths. + client_keys (str | list[str]): Lmdb client keys. Default: 'default'. + readonly (bool, optional): Lmdb environment parameter. If True, + disallow any write operations. Default: True. + lock (bool, optional): Lmdb environment parameter. If False, when + concurrent access occurs, do not lock the database. Default: False. + readahead (bool, optional): Lmdb environment parameter. If False, + disable the OS filesystem readahead mechanism, which may improve + random read performance when a database is larger than RAM. + Default: False. + + Attributes: + db_paths (list): Lmdb database path. + _client (list): A list of several lmdb envs. + """ + + def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, readahead=False, **kwargs): + try: + import lmdb + except ImportError: + raise ImportError('Please install lmdb to enable LmdbBackend.') + + if isinstance(client_keys, str): + client_keys = [client_keys] + + if isinstance(db_paths, list): + self.db_paths = [str(v) for v in db_paths] + elif isinstance(db_paths, str): + self.db_paths = [str(db_paths)] + assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, ' + f'but received {len(client_keys)} and {len(self.db_paths)}.') + + self._client = {} + for client, path in zip(client_keys, self.db_paths): + self._client[client] = lmdb.open(path, readonly=readonly, lock=lock, readahead=readahead, **kwargs) + + def get(self, filepath, client_key): + """Get values according to the filepath from one lmdb named client_key. + + Args: + filepath (str | obj:`Path`): Here, filepath is the lmdb key. + client_key (str): Used for distinguishing differnet lmdb envs. + """ + filepath = str(filepath) + assert client_key in self._client, (f'client_key {client_key} is not ' 'in lmdb clients.') + client = self._client[client_key] + with client.begin(write=False) as txn: + value_buf = txn.get(filepath.encode('ascii')) + return value_buf + + def get_text(self, filepath): + raise NotImplementedError + + +class FileClient(object): + """A general file client to access files in different backend. 
+ + The client loads a file or text in a specified backend from its path + and return it as a binary file. it can also register other backend + accessor with a given name and backend class. + + Attributes: + backend (str): The storage backend type. Options are "disk", + "memcached" and "lmdb". + client (:obj:`BaseStorageBackend`): The backend object. + """ + + _backends = { + 'disk': HardDiskBackend, + 'memcached': MemcachedBackend, + 'lmdb': LmdbBackend, + } + + def __init__(self, backend='disk', **kwargs): + if backend not in self._backends: + raise ValueError(f'Backend {backend} is not supported. Currently supported ones' + f' are {list(self._backends.keys())}') + self.backend = backend + self.client = self._backends[backend](**kwargs) + + def get(self, filepath, client_key='default'): + # client_key is used only for lmdb, where different fileclients have + # different lmdb environments. + if self.backend == 'lmdb': + return self.client.get(filepath, client_key) + else: + return self.client.get(filepath) + + def get_text(self, filepath): + return self.client.get_text(filepath) diff --git a/propainter/utils/flow_util.py b/propainter/utils/flow_util.py new file mode 100644 index 0000000000000000000000000000000000000000..5551e5f173414db8c1088a8fd26b4b0f525fa944 --- /dev/null +++ b/propainter/utils/flow_util.py @@ -0,0 +1,196 @@ +import cv2 +import numpy as np +import os +import torch.nn.functional as F + +def resize_flow(flow, newh, neww): + oldh, oldw = flow.shape[0:2] + flow = cv2.resize(flow, (neww, newh), interpolation=cv2.INTER_LINEAR) + flow[:, :, 0] *= neww / oldw + flow[:, :, 1] *= newh / oldh + return flow + +def resize_flow_pytorch(flow, newh, neww): + oldh, oldw = flow.shape[-2:] + flow = F.interpolate(flow, (newh, neww), mode='bilinear') + flow[:, :, 0] *= neww / oldw + flow[:, :, 1] *= newh / oldh + return flow + + +def imwrite(img, file_path, params=None, auto_mkdir=True): + if auto_mkdir: + dir_name = os.path.abspath(os.path.dirname(file_path)) + os.makedirs(dir_name, exist_ok=True) + return cv2.imwrite(file_path, img, params) + + +def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs): + """Read an optical flow map. + + Args: + flow_path (ndarray or str): Flow path. + quantize (bool): whether to read quantized pair, if set to True, + remaining args will be passed to :func:`dequantize_flow`. + concat_axis (int): The axis that dx and dy are concatenated, + can be either 0 or 1. Ignored if quantize is False. 
+ + Returns: + ndarray: Optical flow represented as a (h, w, 2) numpy array + """ + if quantize: + assert concat_axis in [0, 1] + cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED) + if cat_flow.ndim != 2: + raise IOError(f'{flow_path} is not a valid quantized flow file, its dimension is {cat_flow.ndim}.') + assert cat_flow.shape[concat_axis] % 2 == 0 + dx, dy = np.split(cat_flow, 2, axis=concat_axis) + flow = dequantize_flow(dx, dy, *args, **kwargs) + else: + with open(flow_path, 'rb') as f: + try: + header = f.read(4).decode('utf-8') + except Exception: + raise IOError(f'Invalid flow file: {flow_path}') + else: + if header != 'PIEH': + raise IOError(f'Invalid flow file: {flow_path}, header does not contain PIEH') + + w = np.fromfile(f, np.int32, 1).squeeze() + h = np.fromfile(f, np.int32, 1).squeeze() + # flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2)) + flow = np.fromfile(f, np.float16, w * h * 2).reshape((h, w, 2)) + + return flow.astype(np.float32) + + +def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs): + """Write optical flow to file. + + If the flow is not quantized, it will be saved as a .flo file losslessly, + otherwise a jpeg image which is lossy but of much smaller size. (dx and dy + will be concatenated horizontally into a single image if quantize is True.) + + Args: + flow (ndarray): (h, w, 2) array of optical flow. + filename (str): Output filepath. + quantize (bool): Whether to quantize the flow and save it to 2 jpeg + images. If set to True, remaining args will be passed to + :func:`quantize_flow`. + concat_axis (int): The axis that dx and dy are concatenated, + can be either 0 or 1. Ignored if quantize is False. + """ + dir_name = os.path.abspath(os.path.dirname(filename)) + os.makedirs(dir_name, exist_ok=True) + if not quantize: + with open(filename, 'wb') as f: + f.write('PIEH'.encode('utf-8')) + np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f) + # flow = flow.astype(np.float32) + flow = flow.astype(np.float16) + flow.tofile(f) + f.flush() + else: + assert concat_axis in [0, 1] + dx, dy = quantize_flow(flow, *args, **kwargs) + dxdy = np.concatenate((dx, dy), axis=concat_axis) + # os.makedirs(os.path.dirname(filename), exist_ok=True) + cv2.imwrite(filename, dxdy) + # imwrite(dxdy, filename) + + +def quantize_flow(flow, max_val=0.02, norm=True): + """Quantize flow to [0, 255]. + + After this step, the size of flow will be much smaller, and can be + dumped as jpeg images. + + Args: + flow (ndarray): (h, w, 2) array of optical flow. + max_val (float): Maximum value of flow, values beyond + [-max_val, max_val] will be truncated. + norm (bool): Whether to divide flow values by image width/height. + + Returns: + tuple[ndarray]: Quantized dx and dy. + """ + h, w, _ = flow.shape + dx = flow[..., 0] + dy = flow[..., 1] + if norm: + dx = dx / w # avoid inplace operations + dy = dy / h + # use 255 levels instead of 256 to make sure 0 is 0 after dequantization. + flow_comps = [quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]] + return tuple(flow_comps) + + +def dequantize_flow(dx, dy, max_val=0.02, denorm=True): + """Recover from quantized flow. + + Args: + dx (ndarray): Quantized dx. + dy (ndarray): Quantized dy. + max_val (float): Maximum value used when quantizing. + denorm (bool): Whether to multiply flow values with width/height. + + Returns: + ndarray: Dequantized flow. 
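+
+    Example:
+        Illustrative round trip with :func:`quantize_flow` on a toy flow field:
+
+        >>> import numpy as np
+        >>> flow = np.random.rand(4, 5, 2).astype(np.float32) * 0.01
+        >>> dx, dy = quantize_flow(flow, max_val=0.02)
+        >>> dequantize_flow(dx, dy, max_val=0.02).shape
+        (4, 5, 2)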
+ """ + assert dx.shape == dy.shape + assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1) + + dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]] + + if denorm: + dx *= dx.shape[1] + dy *= dx.shape[0] + flow = np.dstack((dx, dy)) + return flow + + +def quantize(arr, min_val, max_val, levels, dtype=np.int64): + """Quantize an array of (-inf, inf) to [0, levels-1]. + + Args: + arr (ndarray): Input array. + min_val (scalar): Minimum value to be clipped. + max_val (scalar): Maximum value to be clipped. + levels (int): Quantization levels. + dtype (np.type): The type of the quantized array. + + Returns: + tuple: Quantized array. + """ + if not (isinstance(levels, int) and levels > 1): + raise ValueError(f'levels must be a positive integer, but got {levels}') + if min_val >= max_val: + raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})') + + arr = np.clip(arr, min_val, max_val) - min_val + quantized_arr = np.minimum(np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1) + + return quantized_arr + + +def dequantize(arr, min_val, max_val, levels, dtype=np.float64): + """Dequantize an array. + + Args: + arr (ndarray): Input array. + min_val (scalar): Minimum value to be clipped. + max_val (scalar): Maximum value to be clipped. + levels (int): Quantization levels. + dtype (np.type): The type of the dequantized array. + + Returns: + tuple: Dequantized array. + """ + if not (isinstance(levels, int) and levels > 1): + raise ValueError(f'levels must be a positive integer, but got {levels}') + if min_val >= max_val: + raise ValueError(f'min_val ({min_val}) must be smaller than max_val ({max_val})') + + dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - min_val) / levels + min_val + + return dequantized_arr \ No newline at end of file diff --git a/propainter/utils/img_util.py b/propainter/utils/img_util.py new file mode 100644 index 0000000000000000000000000000000000000000..d409a132ff216e6943a276fb5d8cd5f410824883 --- /dev/null +++ b/propainter/utils/img_util.py @@ -0,0 +1,170 @@ +import cv2 +import math +import numpy as np +import os +import torch +from torchvision.utils import make_grid + + +def img2tensor(imgs, bgr2rgb=True, float32=True): + """Numpy array to tensor. + + Args: + imgs (list[ndarray] | ndarray): Input images. + bgr2rgb (bool): Whether to change bgr to rgb. + float32 (bool): Whether to change to float32. + + Returns: + list[tensor] | tensor: Tensor images. If returned results only have + one element, just return tensor. + """ + + def _totensor(img, bgr2rgb, float32): + if img.shape[2] == 3 and bgr2rgb: + if img.dtype == 'float64': + img = img.astype('float32') + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + img = torch.from_numpy(img.transpose(2, 0, 1)) + if float32: + img = img.float() + return img + + if isinstance(imgs, list): + return [_totensor(img, bgr2rgb, float32) for img in imgs] + else: + return _totensor(imgs, bgr2rgb, float32) + + +def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): + """Convert torch Tensors into image numpy arrays. + + After clamping to [min, max], values will be normalized to [0, 1]. + + Args: + tensor (Tensor or list[Tensor]): Accept shapes: + 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W); + 2) 3D Tensor of shape (3/1 x H x W); + 3) 2D Tensor of shape (H x W). + Tensor channel should be in RGB order. + rgb2bgr (bool): Whether to change rgb to bgr. + out_type (numpy type): output types. 
If ``np.uint8``, transform outputs + to uint8 type with range [0, 255]; otherwise, float type with + range [0, 1]. Default: ``np.uint8``. + min_max (tuple[int]): min and max values for clamp. + + Returns: + (Tensor or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of + shape (H x W). The channel order is BGR. + """ + if not (torch.is_tensor(tensor) or (isinstance(tensor, list) and all(torch.is_tensor(t) for t in tensor))): + raise TypeError(f'tensor or list of tensors expected, got {type(tensor)}') + + if torch.is_tensor(tensor): + tensor = [tensor] + result = [] + for _tensor in tensor: + _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max) + _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0]) + + n_dim = _tensor.dim() + if n_dim == 4: + img_np = make_grid(_tensor, nrow=int(math.sqrt(_tensor.size(0))), normalize=False).numpy() + img_np = img_np.transpose(1, 2, 0) + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 3: + img_np = _tensor.numpy() + img_np = img_np.transpose(1, 2, 0) + if img_np.shape[2] == 1: # gray image + img_np = np.squeeze(img_np, axis=2) + else: + if rgb2bgr: + img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR) + elif n_dim == 2: + img_np = _tensor.numpy() + else: + raise TypeError('Only support 4D, 3D or 2D tensor. ' f'But received with dimension: {n_dim}') + if out_type == np.uint8: + # Unlike MATLAB, numpy.unit8() WILL NOT round by default. + img_np = (img_np * 255.0).round() + img_np = img_np.astype(out_type) + result.append(img_np) + if len(result) == 1: + result = result[0] + return result + + +def tensor2img_fast(tensor, rgb2bgr=True, min_max=(0, 1)): + """This implementation is slightly faster than tensor2img. + It now only supports torch tensor with shape (1, c, h, w). + + Args: + tensor (Tensor): Now only support torch tensor with (1, c, h, w). + rgb2bgr (bool): Whether to change rgb to bgr. Default: True. + min_max (tuple[int]): min and max values for clamp. + """ + output = tensor.squeeze(0).detach().clamp_(*min_max).permute(1, 2, 0) + output = (output - min_max[0]) / (min_max[1] - min_max[0]) * 255 + output = output.type(torch.uint8).cpu().numpy() + if rgb2bgr: + output = cv2.cvtColor(output, cv2.COLOR_RGB2BGR) + return output + + +def imfrombytes(content, flag='color', float32=False): + """Read an image from bytes. + + Args: + content (bytes): Image bytes got from files or other streams. + flag (str): Flags specifying the color type of a loaded image, + candidates are `color`, `grayscale` and `unchanged`. + float32 (bool): Whether to change to float32., If True, will also norm + to [0, 1]. Default: False. + + Returns: + ndarray: Loaded image array. + """ + img_np = np.frombuffer(content, np.uint8) + imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED} + img = cv2.imdecode(img_np, imread_flags[flag]) + if float32: + img = img.astype(np.float32) / 255. + return img + + +def imwrite(img, file_path, params=None, auto_mkdir=True): + """Write image to file. + + Args: + img (ndarray): Image array to be written. + file_path (str): Image file path. + params (None or list): Same as opencv's :func:`imwrite` interface. + auto_mkdir (bool): If the parent folder of `file_path` does not exist, + whether to create it automatically. + + Returns: + bool: Successful or not. 
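+
+    Example:
+        Illustrative call with a toy image and a hypothetical output path
+        (parent directories are created automatically when auto_mkdir is True):
+
+        >>> import numpy as np
+        >>> ok = imwrite(np.zeros((64, 64, 3), dtype=np.uint8), 'results/demo.png')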
+ """ + if auto_mkdir: + dir_name = os.path.abspath(os.path.dirname(file_path)) + os.makedirs(dir_name, exist_ok=True) + return cv2.imwrite(file_path, img, params) + + +def crop_border(imgs, crop_border): + """Crop borders of images. + + Args: + imgs (list[ndarray] | ndarray): Images with shape (h, w, c). + crop_border (int): Crop border for each end of height and weight. + + Returns: + list[ndarray]: Cropped images. + """ + if crop_border == 0: + return imgs + else: + if isinstance(imgs, list): + return [v[crop_border:-crop_border, crop_border:-crop_border, ...] for v in imgs] + else: + return imgs[crop_border:-crop_border, crop_border:-crop_border, ...] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..832c939b315597309dcf0324582c296813760fe2 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,19 @@ +torch==2.3.1 +torchvision==0.18.1 +torchaudio==2.3.1 +diffusers==0.29.2 +accelerate==0.25.0 +opencv-python==4.9.0.80 +imageio==2.34.1 +matplotlib +transformers==4.41.1 +einops==0.8.0 +datasets==2.19.1 +numpy==1.26.4 +pillow==10.4.0 +tqdm==4.66.4 +urllib3==2.2.2 +zipp==3.19.2 +peft==0.13.2 +scipy==1.13.1 +av==14.0.1 \ No newline at end of file diff --git a/run_diffueraser.py b/run_diffueraser.py new file mode 100644 index 0000000000000000000000000000000000000000..473c047462f992c156302d8f106f82362bb175fd --- /dev/null +++ b/run_diffueraser.py @@ -0,0 +1,62 @@ +import torch +import os +import time +import argparse +from diffueraser.diffueraser import DiffuEraser +from propainter.inference import Propainter, get_device + +def main(): + + ## input params + parser = argparse.ArgumentParser() + parser.add_argument('--input_video', type=str, default="examples/example3/video.mp4", help='Path to the input video') + parser.add_argument('--input_mask', type=str, default="examples/example3/mask.mp4" , help='Path to the input mask') + parser.add_argument('--video_length', type=int, default=10, help='The maximum length of output video') + parser.add_argument('--mask_dilation_iter', type=int, default=8, help='Adjust it to change the degree of mask expansion') + parser.add_argument('--max_img_size', type=int, default=960, help='The maximum length of output width and height') + parser.add_argument('--save_path', type=str, default="results" , help='Path to the output') + parser.add_argument('--ref_stride', type=int, default=10, help='Propainter params') + parser.add_argument('--neighbor_length', type=int, default=10, help='Propainter params') + parser.add_argument('--subvideo_length', type=int, default=50, help='Propainter params') + parser.add_argument('--base_model_path', type=str, default="weights/stable-diffusion-v1-5" , help='Path to sd1.5 base model') + parser.add_argument('--vae_path', type=str, default="weights/sd-vae-ft-mse" , help='Path to vae') + parser.add_argument('--diffueraser_path', type=str, default="weights/diffuEraser" , help='Path to DiffuEraser') + parser.add_argument('--propainter_model_dir', type=str, default="weights/propainter" , help='Path to priori model') + args = parser.parse_args() + + if not os.path.exists(args.save_path): + os.makedirs(args.save_path) + priori_path = os.path.join(args.save_path, "priori.mp4") + output_path = os.path.join(args.save_path, "diffueraser_result.mp4") + + ## model initialization + device = get_device() + # PCM params + ckpt = "2-Step" + video_inpainting_sd = DiffuEraser(device, args.base_model_path, args.vae_path, args.diffueraser_path, ckpt=ckpt) + propainter = 
Propainter(args.propainter_model_dir, device=device) + + start_time = time.time() + + ## priori + propainter.forward(args.input_video, args.input_mask, priori_path, video_length=args.video_length, + ref_stride=args.ref_stride, neighbor_length=args.neighbor_length, subvideo_length = args.subvideo_length, + mask_dilation = args.mask_dilation_iter) + + ## diffueraser + guidance_scale = None # The default value is 0. + video_inpainting_sd.forward(args.input_video, args.input_mask, priori_path, output_path, + max_img_size = args.max_img_size, video_length=args.video_length, mask_dilation_iter=args.mask_dilation_iter, + guidance_scale=guidance_scale) + + end_time = time.time() + inference_time = end_time - start_time + print(f"DiffuEraser inference time: {inference_time:.4f} s") + + torch.cuda.empty_cache() + +if __name__ == '__main__': + main() + + + \ No newline at end of file diff --git a/weights/README.md b/weights/README.md new file mode 100644 index 0000000000000000000000000000000000000000..d6b8403cd6916408580cbc55bc4c83a586c6a2bc --- /dev/null +++ b/weights/README.md @@ -0,0 +1,22 @@ +Put the downloaded pre-trained models to this folder. + +The directory structure will be arranged as: +``` +weights + |- diffuEraser + |-brushnet + |-unet_main + |- stable-diffusion-v1-5 + |-feature_extractor + |-... + |- PCM_Weights + |-sd15 + |- propainter + |-ProPainter.pth + |-raft-things.pth + |-recurrent_flow_completion.pth + |- sd-vae-ft-mse + |-diffusion_pytorch_model.bin + |-... + |- README.md +``` \ No newline at end of file
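+
+Once the weights are in place, the demo script can be run from the repository
+root. A minimal sketch using the example clip referenced by the script defaults
+(adjust the paths if your layout differs):
+```
+python run_diffueraser.py --input_video examples/example3/video.mp4 \
+    --input_mask examples/example3/mask.mp4 --save_path results
+```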