diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..a6344aac8c09253b3b630fb776ae94478aa0275b --- /dev/null +++ b/.gitattributes @@ -0,0 +1,35 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..13ecfc75b02df61296a4e0b7118c605debba1661 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +app1.py +app2.py +demo_utils1.py +tmp +models \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..c9a97e5905a231cc55e918d9096fcbf2c05bb423 --- /dev/null +++ b/README.md @@ -0,0 +1,141 @@ +--- +title: "RelightVid" +emoji: "💡" +colorFrom: "blue" +colorTo: "green" +sdk: "gradio" # 你的项目使用的 SDK (gradio / streamlit / docker) +app_file: "app.py" # 你的主程序文件 +--- + + + +# RelightVid + +**[RelightVid: Temporal-Consistent Diffusion Model for Video Relighting](https://arxiv.org/abs/2501.16330)** +
+[Ye Fang](https://github.com/Aleafy)\*, +[Zeyi Sun](https://github.com/SunzeY)\*, +[Shangzhan Zhang](https://zhanghe3z.github.io/), +[Tong Wu](https://wutong16.github.io/), +[Yinghao Xu](https://justimyhxu.github.io/), +[Pan Zhang](https://panzhang0212.github.io/), +[Jiaqi Wang](https://myownskyw7.github.io/), +[Gordon Wetzstein](https://web.stanford.edu/~gordonwz/), +[Dahua Lin](http://dahua.site/) + +

*Equal Contribution

+

+ + + + + + +

+ + +![Demo](./assets/demo.gif) + + +## 📜 News +🚀 [2024/6/8] We release our [inference pipeline of Make-it-Real](#⚡-quick-start), including material matching and generation of albedo-only 3D objects. + +🚀 [2024/6/8] [Material library annotations](#📦-data-preparation) generated by GPT-4V and [data engine](#⚡-quick-start) are released! + +🚀 [2024/4/26] The [paper](https://arxiv.org/abs/2404.16829) and [project page](https://sunzey.github.io/Make-it-Real) are released! + +## 💡 Highlights +- 🔥 We first demonstrate that **GPT-4V** can effectively **recognize and describe materials**, allowing our model to precisely identifies and aligns materials with the corresponding components of 3D objects. +- 🔥 We construct a **Material Library** containing thousands of materials with highly +detailed descriptions readily for MLLMs to look up and assign. +- 🔥 **An effective pipeline** for texture segmentation, material identification and matching, enabling the high-quality application of materials to +3D assets. + +## 👨‍💻 Todo +- [ ] Evaluation for Existed and Model-Generated Assets (both code & test assets) +- [ ] More Interactive Demos (huggingface, jupyter) +- [x] Make-it-Real Pipeline Inference Code +- [x] Highly detailed Material Library annotations (generated by GPT-4V) +- [x] Paper and Web Demos + +## 💾 Installation + + ```bash + git clone https://github.com/Aleafy/RelightVid.git + cd RelightVid + + conda create -n relitv python=3.10 + conda activate relitv + + pip install torch==2.1.2 torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118 + pip install -r requirements.txt + ``` + + + + + +## 📦 Data Preparation + 1. **Annotations**: in `data/material_lib/annotations` [folder](data/material_lib/annotations), include: + - Highly-detailed descriptions by GPT-4V: offering thorough descriptions of the material’s visual characteristics and rich semantic information. + - Category-tree: Divided into a hierarchical structure with coarse and fine granularity, it includes over 80 subcategories. + 2. **PBR Maps**: You can download the complete PBR data collection at [Huggingface](https://huggingface.co/datasets/gvecchio/MatSynth/tree/main), or download the data used in our project at [OpenXLab](https://openxlab.org.cn/datasets/YeFang/MatSynth/tree/main) (Recommended). (If you have any questions, please refer to [issue#5](https://github.com/Aleafy/Make_it_Real/issues/5)) + 3. **Material Images(optinal)**: You can download the material images file [here](https://drive.google.com/file/d/1ob7CV6JiaqFyjuCzlmSnBuNRkzt2qMSG/view?usp=sharing), to check and visualize the material appearance. + +
+Make_it_Real
+└── data
+    └── material_lib
+        ├── annotations
+        ├── mat_images
+        └── pbr_maps
+            └── train
+                ├── Ceremic
+                ├── Concrete
+                ├── ...
+                └── Wood
+
+ + + +## ⚡ Quick Start +#### Inference +```bash +python main.py --obj_dir --exp_name --api_key +``` +- To ensure proper network connectivity for GPT-4V, add proxy environment settings in [main.py](https://github.com/Aleafy/Make_it_Real/blob/feb3563d57fbe18abbff8d4abfb48f71cc8f967b/main.py#L18) (optional). Also, please verify the reachability of your [API host](https://github.com/Aleafy/Make_it_Real/blob/feb3563d57fbe18abbff8d4abfb48f71cc8f967b/utils/gpt4_query.py#L68). +- Result visualization (blender engine) is located in the `output/refine_output` dir. You can compare the result with that in `output/ori_output`. + +#### Annotation Engine + +```bash +cd scripts/gpt_anno +python gpt4_query_mat.py +``` +`Note`: Besides functinoning as annotation engine, you can also use this code ([gpt4_query_mat.py](https://github.com/Aleafy/Make_it_Real/blob/main/scripts/gpt_anno/gpt4_query_mat.py)) to test the GPT-4V connection simply. + + + + + + +## ❤️ Acknowledgments +- [MatSynth](https://huggingface.co/datasets/gvecchio/MatSynth/tree/main): a Physically Based Rendering (PBR) materials dataset, which offers extensive high-resolusion tilable pbr maps to look up. +- [TEXTure](https://github.com/TEXTurePaper/TEXTurePaper): Wonderful text-guided texture generation model, and the codebase we built upon. +- [SoM](https://som-gpt4v.github.io/): Draw visual cues on images to facilate GPT-4V query better. +- [Material Palette](https://github.com/astra-vision/MaterialPalette): Excellent exploration of material extraction and generation, offers good insights and comparable setting. + +## ✒️ Citation +If you find our work helpful for your research, please consider giving a star ⭐ and citation 📝 +```bibtex +@misc{fang2024makeitreal, + title={Make-it-Real: Unleashing Large Multimodal Model for Painting 3D Objects with Realistic Materials}, + author={Ye Fang and Zeyi Sun and Tong Wu and Jiaqi Wang and Ziwei Liu and Gordon Wetzstein and Dahua Lin}, + year={2024}, + eprint={2404.16829}, + archivePrefix={arXiv}, + primaryClass={cs.CV} +} +``` + + diff --git a/__pycache__/db_examples.cpython-310.pyc b/__pycache__/db_examples.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4849adc049a20e290b188a48d531b56b97780b45 Binary files /dev/null and b/__pycache__/db_examples.cpython-310.pyc differ diff --git a/__pycache__/demo_utils1.cpython-310.pyc b/__pycache__/demo_utils1.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb3ed080f692b5ee8a6c95968e7d307fb77eca9b Binary files /dev/null and b/__pycache__/demo_utils1.cpython-310.pyc differ diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..2f8f28e520a59e7483bd9204dfaeeea64e5bcfa7 --- /dev/null +++ b/app.py @@ -0,0 +1,365 @@ +import os +import gradio as gr +import numpy as np +from enum import Enum +import db_examples +import cv2 + +# from demo_utils1 import * + +from misc_utils.train_utils import unit_test_create_model +from misc_utils.image_utils import save_tensor_to_gif, save_tensor_to_images +import os +from PIL import Image +import torch +import torchvision +from torchvision import transforms +from einops import rearrange +import imageio +import time + +from torchvision.transforms import functional as F + +import os + +# 推理设置 +from pl_trainer.inference.inference import InferenceIP2PVideo +from tqdm import tqdm + +# 下载文件 +os.makedirs('models', exist_ok=True) +filename = "models/iclight_sd15_fbc.safetensors" + +# if not os.path.exists(filename): +# original_path = os.getcwd() +# base_path = './models' +# os.makedirs(base_path, exist_ok=True) + +# # 直接在代码中写入 Token(注意安全风险) +# GIT_TOKEN = "955b8ea91095840b76fe38b90a088c200d4c813c" +# repo_url = f"https://YeFang:{GIT_TOKEN}@code.openxlab.org.cn/YeFang/RIV_models.git" + +# try: +# if os.system(f'git clone {repo_url} {base_path}') != 0: +# raise RuntimeError("Git 克隆失败") +# os.chdir(base_path) +# if os.system('git lfs pull') != 0: +# raise RuntimeError("Git LFS 拉取失败") +# finally: +# os.chdir(original_path) + +def tensor_to_pil_image(x): + """ + 将 4D PyTorch 张量转换为 PIL 图像。 + """ + x = x.float() # 确保张量类型为 float + grid_img = torchvision.utils.make_grid(x, nrow=4).permute(1, 2, 0).detach().cpu().numpy() + grid_img = (grid_img * 255).clip(0, 255).astype("uint8") # 将 [0, 1] 范围转换为 [0, 255] + return Image.fromarray(grid_img) + +def frame_to_batch(x): + """ + 将帧维度转换为批次维度。 + """ + return rearrange(x, 'b f c h w -> (b f) c h w') + +def clip_image(x, min=0., max=1.): + """ + 将图像张量裁剪到指定的最小和最大值。 + """ + return torch.clamp(x, min=min, max=max) + +def unnormalize(x): + """ + 将张量范围从 [-1, 1] 转换到 [0, 1]。 + """ + return (x + 1) / 2 + + +# 读取图像文件 +def read_images_from_directory(directory, num_frames=16): + images = [] + for i in range(num_frames): + img_path = os.path.join(directory, f'{i:04d}.png') + img = imageio.imread(img_path) + images.append(torch.tensor(img).permute(2, 0, 1)) # Convert to Tensor (C, H, W) + return images + +def load_and_process_images(folder_path): + """ + 读取文件夹中的所有图片,将它们转换为 [-1, 1] 范围的张量并返回一个 4D 张量。 + """ + processed_images = [] + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Lambda(lambda x: x * 2 - 1) # 将 [0, 1] 转换为 [-1, 1] + ]) + for filename in sorted(os.listdir(folder_path)): + if filename.endswith(".png"): + img_path = os.path.join(folder_path, filename) + image = Image.open(img_path).convert("RGB") + processed_image = transform(image) + processed_images.append(processed_image) + return torch.stack(processed_images) # 返回 4D 张量 + +def load_and_process_video(video_path, num_frames=16, crop_size=512): + """ + 读取视频文件中的前 num_frames 帧,将每一帧转换为 [-1, 1] 范围的张量, + 并进行中心裁剪至 crop_size x crop_size,返回一个 4D 张量。 + """ + processed_frames = [] + transform = transforms.Compose([ + transforms.CenterCrop(crop_size), # 中心裁剪 + transforms.ToTensor(), + transforms.Lambda(lambda x: x * 2 - 1) # 将 [0, 1] 转换为 [-1, 1] + ]) + + # 使用 OpenCV 读取视频 + cap = cv2.VideoCapture(video_path) + + if not cap.isOpened(): + raise ValueError(f"无法打开视频文件: {video_path}") + + frame_count = 0 + + while frame_count < num_frames: + ret, frame = cap.read() + if not ret: + break # 视频帧读取完毕或视频帧不足 + + # 转换为 RGB 格式 + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + image = Image.fromarray(frame) + + # 应用转换 + processed_frame = transform(image) + processed_frames.append(processed_frame) + + frame_count += 1 + + cap.release() # 释放视频资源 + + if len(processed_frames) < num_frames: + raise ValueError(f"视频帧不足 {num_frames} 帧,仅找到 {len(processed_frames)} 帧。") + + return torch.stack(processed_frames) # 返回 4D 张量 (帧数, 通道数, 高度, 宽度) + + +def clear_cache(output_path): + if os.path.exists(output_path): + os.remove(output_path) + return None + + +#! 加载模型 +# 配置路径和加载模型 +config_path = 'configs/instruct_v2v_ic_gradio.yaml' +diffusion_model = unit_test_create_model(config_path).cuda() + +# 加载模型检查点 +ckpt_path = 'models/pytorch_model.bin' #! change +ckpt = torch.load(ckpt_path, map_location='cpu') +diffusion_model.load_state_dict(ckpt, strict=False) + +# import pdb; pdb.set_trace() + +# # 更改全局临时目录 +# new_tmp_dir = "./demo/gradio_bg" +# os.makedirs(new_tmp_dir, exist_ok=True) + +# import pdb; pdb.set_trace() + +def save_video_from_frames(image_pred, save_pth, fps=8): + """ + 将 image_pred 中的帧保存为视频文件。 + + 参数: + - image_pred: Tensor,形状为 (1, 16, 3, 512, 512) + - save_pth: 保存视频的路径,例如 "output_video.mp4" + - fps: 视频的帧率 + """ + # 视频参数 + num_frames = image_pred.shape[1] + frame_height, frame_width = 512, 512 # 目标尺寸 + fourcc = cv2.VideoWriter_fourcc(*'mp4v') # 使用 mp4 编码格式 + + # 创建 VideoWriter 对象 + out = cv2.VideoWriter(save_pth, fourcc, fps, (frame_width, frame_height)) + + for i in range(num_frames): + # 反归一化 + 转换为 0-255 范围 + pred_frame = clip_image(unnormalize(image_pred[0][i].unsqueeze(0))) * 255 + pred_frame_resized = pred_frame.squeeze(0).detach().cpu() # (3, 512, 512) + pred_frame_resized = pred_frame_resized.permute(1, 2, 0).numpy().astype("uint8") # (512, 512, 3) + + # Resize 到 256x256 + pred_frame_resized = cv2.resize(pred_frame_resized, (frame_width, frame_height)) + + # 将 RGB 转为 BGR(因为 OpenCV 使用 BGR 格式) + pred_frame_bgr = cv2.cvtColor(pred_frame_resized, cv2.COLOR_RGB2BGR) + + # 写入帧到视频 + out.write(pred_frame_bgr) + + # 释放 VideoWriter 资源 + out.release() + print(f"视频已保存至 {save_pth}") + + +# 伪函数占位(生成空白视频) +def dummy_process(input_fg, input_bg): + # import pdb; pdb.set_trace() + fg_tensor = load_and_process_video(input_fg).cuda().unsqueeze(0) + bg_tensor = load_and_process_video(input_bg).cuda().unsqueeze(0) # (1, 16, 4, 64, 64) + + cond_fg_tensor = diffusion_model.encode_image_to_latent(fg_tensor) # (1, 16, 4, 64, 64) + cond_bg_tensor = diffusion_model.encode_image_to_latent(bg_tensor) + cond_tensor = torch.cat((cond_fg_tensor, cond_bg_tensor), dim=2) + + # 初始化潜变量 + init_latent = torch.randn_like(cond_fg_tensor) + + inf_pipe = InferenceIP2PVideo( + diffusion_model.unet, + scheduler='ddpm', + num_ddim_steps=20 + ) + + EDIT_PROMPT = 'change the background' + VIDEO_CFG = 1.2 + TEXT_CFG = 7.5 + text_cond = diffusion_model.encode_text([EDIT_PROMPT]) # (1, 77, 768) + text_uncond = diffusion_model.encode_text(['']) + latent_pred = inf_pipe( + latent=init_latent, + text_cond=text_cond, + text_uncond=text_uncond, + img_cond=cond_tensor, + text_cfg=TEXT_CFG, + img_cfg=VIDEO_CFG, + )['latent'] + + image_pred = diffusion_model.decode_latent_to_image(latent_pred) # (1,16,3,512,512) + output_path = os.path.join(new_tmp_dir, f"output_{int(time.time())}.mp4") + # clear_cache(output_path) + + save_video_from_frames(image_pred, output_path) + # import pdb; pdb.set_trace() + # fps = 8 + # frames = [] + # for i in range(16): + # pred_frame = clip_image(unnormalize(image_pred[0][i].unsqueeze(0))) * 255 + # pred_frame_resized = pred_frame.squeeze(0).detach().cpu() #(3,512,512) + # pred_frame_resized = pred_frame_resized.permute(1, 2, 0).detach().cpu().numpy().astype("uint8") #(512,512,3) np + # Image.fromarray(pred_frame_resized).save(save_pth) + + # # 生成一个简单的黑色视频作为示例 + # output_path = os.path.join(new_tmp_dir, "output.mp4") + # fourcc = cv2.VideoWriter_fourcc(*'mp4v') + # out = cv2.VideoWriter(output_path, fourcc, 20.0, (512, 512)) + + # for _ in range(60): # 生成 3 秒的视频(20fps) + # frame = np.zeros((512, 512, 3), dtype=np.uint8) + # out.write(frame) + # out.release() + + return output_path + +# 枚举类用于背景选择 +class BGSource(Enum): + UPLOAD = "Use Background Video" + UPLOAD_FLIP = "Use Flipped Background Video" + LEFT = "Left Light" + RIGHT = "Right Light" + TOP = "Top Light" + BOTTOM = "Bottom Light" + GREY = "Ambient" + +# Quick prompts 示例 +quick_prompts = [ + 'beautiful woman', + 'handsome man', + 'beautiful woman, cinematic lighting', + 'handsome man, cinematic lighting', + 'beautiful woman, natural lighting', + 'handsome man, natural lighting', + 'beautiful woman, neo punk lighting, cyberpunk', + 'handsome man, neo punk lighting, cyberpunk', +] +quick_prompts = [[x] for x in quick_prompts] + +# Gradio UI 结构 +block = gr.Blocks().queue() +with block: + with gr.Row(): + gr.Markdown("## IC-Light (Relighting with Foreground and Background Video Condition)") + + with gr.Row(): + with gr.Column(): + with gr.Row(): + input_fg = gr.Video(label="Foreground Video", height=370, width=370, visible=True) + input_bg = gr.Video(label="Background Video", height=370, width=370, visible=True) + + prompt = gr.Textbox(label="Prompt") + bg_source = gr.Radio(choices=[e.value for e in BGSource], + value=BGSource.UPLOAD.value, + label="Background Source", type='value') + + example_prompts = gr.Dataset(samples=quick_prompts, label='Prompt Quick List', components=[prompt]) + bg_gallery = gr.Gallery(height=450, object_fit='contain', label='Background Quick List', value=db_examples.bg_samples, columns=5, allow_preview=False) + relight_button = gr.Button(value="Relight") + + with gr.Group(): + with gr.Row(): + num_samples = gr.Slider(label="Videos", minimum=1, maximum=12, value=1, step=1) + seed = gr.Number(label="Seed", value=12345, precision=0) + with gr.Row(): + video_width = gr.Slider(label="Video Width", minimum=256, maximum=1024, value=512, step=64) + video_height = gr.Slider(label="Video Height", minimum=256, maximum=1024, value=640, step=64) + + with gr.Accordion("Advanced options", open=False): + steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=20, step=1) + cfg = gr.Slider(label="CFG Scale", minimum=1.0, maximum=32.0, value=7.0, step=0.01) + highres_scale = gr.Slider(label="Highres Scale", minimum=1.0, maximum=3.0, value=1.5, step=0.01) + highres_denoise = gr.Slider(label="Highres Denoise", minimum=0.1, maximum=0.9, value=0.5, step=0.01) + a_prompt = gr.Textbox(label="Added Prompt", value='best quality') + n_prompt = gr.Textbox(label="Negative Prompt", value='lowres, bad anatomy, bad hands, cropped, worst quality') + normal_button = gr.Button(value="Compute Normal (4x Slower)") + + with gr.Column(): + result_video = gr.Video(label='Output Video', height=600, width=600, visible=True) + + # 输入列表 + # ips = [input_fg, input_bg, prompt, video_width, video_height, num_samples, seed, steps, a_prompt, n_prompt, cfg, highres_scale, highres_denoise, bg_source] + ips = [input_fg, input_bg] + + # 按钮绑定处理函数 + # relight_button.click(fn=lambda: None, inputs=[], outputs=[result_video]) + + relight_button.click(fn=dummy_process, inputs=ips, outputs=[result_video]) + + normal_button.click(fn=dummy_process, inputs=ips, outputs=[result_video]) + + # 背景库选择 + def bg_gallery_selected(gal, evt: gr.SelectData): + # import pdb; pdb.set_trace() + # img_path = gal[evt.index][0] + img_path = db_examples.bg_samples[evt.index] + video_path = img_path.replace('frames/0000.png', 'cropped_video.mp4') + return video_path + + bg_gallery.select(bg_gallery_selected, inputs=bg_gallery, outputs=input_bg) + + # 示例 + # dummy_video_for_outputs = gr.Video(visible=False, label='Result') + gr.Examples( + fn=lambda *args: args[-1], + examples=db_examples.background_conditioned_examples, + inputs=[input_fg, input_bg, prompt, bg_source, video_width, video_height, seed, result_video], + outputs=[result_video], + run_on_click=True, examples_per_page=1024 + ) + +# 启动 Gradio 应用 +# block.launch(server_name='0.0.0.0', server_port=10002, share=True) +block.launch(share=True) diff --git a/configs/instruct_v2v.yaml b/configs/instruct_v2v.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1cc351324b5f82b7ef879954885b49fe52806061 --- /dev/null +++ b/configs/instruct_v2v.yaml @@ -0,0 +1,149 @@ +expt_dir: experiments +expt_name: instruct_v2v_ic +trainer_args: + max_epochs: 10 + accelerator: "gpu" + devices: [0] + limit_train_batches: 2048 + limit_val_batches: 5 #! 这边限制了每个epoch只跑多少个batch的validation + # strategy: "ddp" + strategy: "deepspeed_stage_2" + accumulate_grad_batches: 128 #! 注意一下这个值 + check_val_every_n_epoch: 1 #! check一下这个值是不是和记录有关。。。 +diffusion: + target: pl_trainer.instruct_p2p_video.InstructP2PVideoTrainerTemporal + params: + beta_schedule_args: + beta_schedule: scaled_linear + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + clip_sample: false + thresholding: false + prediction_type: epsilon + loss_fn: l2 + optim_args: + lr: 1e-5 + unet_init_weights: #! 注意一下, 完全可以从iv2v的ckpt开始train + - pretrained_models/instruct_pix2pix/diffusion_pytorch_model.bin # 这边sd加载的是ip2p的 + - pretrained_models/Motion_Module/mm_sd_v15.ckpt # motion module, 推测加载的是animatediff的 + vae_init_weights: pretrained_models/instruct_pix2pix/vqvae.ckpt + text_model_init_weights: pretrained_models/instruct_pix2pix/text.ckpt + scale_factor: 0.18215 + guidance_scale: 5 # not used + ddim_sampling_steps: 20 + text_cfg: 7.5 + img_cfg: 1.2 + cond_image_dropout: 0.1 + prompt_type: edit_prompt +unet: + target: modules.video_unet_temporal.unet.UNet3DConditionModel + params: + in_channels: 4 #! 8-> 4 + out_channels: 4 + act_fn: silu + attention_head_dim: 8 + block_out_channels: + - 320 + - 640 + - 1280 + - 1280 + cross_attention_dim: 768 + down_block_types: + - CrossAttnDownBlock3D + - CrossAttnDownBlock3D + - CrossAttnDownBlock3D + - DownBlock3D + up_block_types: + - UpBlock3D + - CrossAttnUpBlock3D + - CrossAttnUpBlock3D + - CrossAttnUpBlock3D + downsample_padding: 1 + layers_per_block: 2 + mid_block_scale_factor: 1 + norm_eps: 1e-05 + norm_num_groups: 32 + sample_size: 64 + use_motion_module: true + motion_module_resolutions: + - 1 + - 2 + - 4 + - 8 + motion_module_mid_block: false + motion_module_decoder_only: false + motion_module_type: Vanilla + motion_module_kwargs: + num_attention_heads: 8 + num_transformer_block: 1 + attention_block_types: + - Temporal_Self + - Temporal_Self + temporal_position_encoding: true + temporal_position_encoding_max_len: 32 + temporal_attention_dim_div: 1 +vae: + target: modules.kl_autoencoder.autoencoder.AutoencoderKL + params: + embed_dim: 4 + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity +text_model: + target: modules.openclip.modules.FrozenCLIPEmbedder + params: + freeze: true +data: + batch_size: 1 + val_batch_size: 1 + train: + target: dataset.videoP2P.VideoPromptToPromptMotionAug + params: #注意修改一下training的路径,和相关加载的代码, 比如说没有meta.yaml这些参数怎么搞 + root_dirs: + - /home/fy/Code/instruct-video-to-video/data_train/Girl + num_frames: 16 + zoom_ratio: 0.2 + max_zoom: 1.25 + translation_ratio: 0.7 + translation_range: [0, 0.2] + is_train: True + val: + target: dataset.videoP2P.VideoPromptToPromptMotionAug + params: + root_dirs: + - /home/fy/Code/instruct-video-to-video/data_train/Girl + num_frames: 16 + zoom_ratio: 0.2 + max_zoom: 1.25 + translation_ratio: 0.7 + translation_range: [0, 0.2] +callbacks: + - target: pytorch_lightning.callbacks.ModelCheckpoint + params: + dirpath: "${expt_dir}/${expt_name}" + filename: "{epoch:04d}" + monitor: epoch + mode: max + save_top_k: 5 + save_last: true + - target: callbacks.instruct_p2p_video.InstructP2PLogger + params: + max_num_images: 1 + # accumulate_grad_batches: 128 + require_wandb: true + - target: pytorch_lightning.callbacks.DeviceStatsMonitor \ No newline at end of file diff --git a/configs/instruct_v2v_ic.yaml b/configs/instruct_v2v_ic.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1b9f36cc551e76539028fc13565cf86d14fbf947 --- /dev/null +++ b/configs/instruct_v2v_ic.yaml @@ -0,0 +1,130 @@ +expt_dir: experiments +expt_name: instruct_v2v_ic +trainer_args: + max_epochs: 10 + accelerator: "gpu" + devices: [0] + limit_train_batches: 2048 + limit_val_batches: 5 #! 这边限制了每个epoch只跑多少个batch的validation + # strategy: "ddp" + strategy: "deepspeed_stage_2" + accumulate_grad_batches: 128 #! 注意一下这个值 + check_val_every_n_epoch: 1 #! check一下这个值是不是和记录有关。。。 +diffusion: + target: pl_trainer.instruct_p2p_video.InstructP2PVideoTrainerTemporal + params: + beta_schedule_args: + beta_schedule: scaled_linear + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + clip_sample: false + thresholding: false + prediction_type: epsilon + loss_fn: l2 + optim_args: + lr: 1e-5 + unet_init_weights: #! 注意一下, 完全可以从iv2v的ckpt开始train + - unet/diffusion_pytorch_model.safetensors # iclight, unet, sf tensor + - pretrained_models/Motion_Module/mm_sd_v15.ckpt # motion module, 推测加载的是animatediff的 + - pretrained_models/iclight/iclight_sd15_fbc.safetensors # iclight lora weights + base_path: /home/fy/.cache/huggingface/hub/models--stablediffusionapi--realistic-vision-v51/snapshots/19e3643d7d963c156d01537188ec08f0b79a514a + # vae_init_weights: pretrained_models/instruct_pix2pix/vqvae.ckpt + # text_model_init_weights: pretrained_models/instruct_pix2pix/text.ckpt #! 这两个可以直接设为None, 从from_pretrained中加载 + scale_factor: 0.18215 + guidance_scale: 5 # not used + ddim_sampling_steps: 20 + text_cfg: 7.5 + img_cfg: 1.2 + cond_image_dropout: 0.1 + prompt_type: edit_prompt +unet: + target: modules.video_unet_temporal.unet.UNet3DConditionModel + params: + in_channels: 4 #! change:8->12 iclight 改为12 注意一下... + out_channels: 4 + act_fn: silu + attention_head_dim: 8 + block_out_channels: + - 320 + - 640 + - 1280 + - 1280 + cross_attention_dim: 768 + down_block_types: + - CrossAttnDownBlock3D + - CrossAttnDownBlock3D + - CrossAttnDownBlock3D + - DownBlock3D + up_block_types: + - UpBlock3D + - CrossAttnUpBlock3D + - CrossAttnUpBlock3D + - CrossAttnUpBlock3D + downsample_padding: 1 + layers_per_block: 2 + mid_block_scale_factor: 1 + norm_eps: 1e-05 + norm_num_groups: 32 + sample_size: 64 + use_motion_module: false #!!! 这边test iclight的时候可以不用motion module 即False + motion_module_resolutions: + - 1 + - 2 + - 4 + - 8 + motion_module_mid_block: false + motion_module_decoder_only: false + motion_module_type: Vanilla + motion_module_kwargs: + num_attention_heads: 8 + num_transformer_block: 1 + attention_block_types: + - Temporal_Self + - Temporal_Self + temporal_position_encoding: true + temporal_position_encoding_max_len: 32 + temporal_attention_dim_div: 1 +text_model: + target: modules.openclip.modules.FrozenCLIPEmbedder + params: + freeze: true +data: + batch_size: 1 + val_batch_size: 1 + train: + target: dataset.videoP2P.VideoPromptToPromptMotionAug + params: #注意修改一下training的路径,和相关加载的代码, 比如说没有meta.yaml这些参数怎么搞 + root_dirs: + - /home/fy/Code/instruct-video-to-video/data_train/Girl + num_frames: 16 + zoom_ratio: 0.2 + max_zoom: 1.25 + translation_ratio: 0.7 + translation_range: [0, 0.2] + is_train: True + val: + target: dataset.videoP2P.VideoPromptToPromptMotionAug + params: + root_dirs: + - /home/fy/Code/instruct-video-to-video/data_train/Girl + num_frames: 16 + zoom_ratio: 0.2 + max_zoom: 1.25 + translation_ratio: 0.7 + translation_range: [0, 0.2] +callbacks: + - target: pytorch_lightning.callbacks.ModelCheckpoint + params: + dirpath: "${expt_dir}/${expt_name}" + filename: "{epoch:04d}" + monitor: epoch + mode: max + save_top_k: 5 + save_last: true + - target: callbacks.instruct_p2p_video.InstructP2PLogger + params: + max_num_images: 1 + # accumulate_grad_batches: 128 + require_wandb: true + - target: pytorch_lightning.callbacks.DeviceStatsMonitor \ No newline at end of file diff --git a/configs/instruct_v2v_ic_gradio.yaml b/configs/instruct_v2v_ic_gradio.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4d7cb60bc1ec913c9ac3611ed0f8a243de28ff17 --- /dev/null +++ b/configs/instruct_v2v_ic_gradio.yaml @@ -0,0 +1,81 @@ +diffusion: + target: pl_trainer.instruct_p2p_video.InstructP2PVideoTrainerTemporal + params: + beta_schedule_args: + beta_schedule: scaled_linear + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + clip_sample: false + thresholding: false + prediction_type: epsilon + loss_fn: l2 + optim_args: + lr: 1e-5 + # base_path: models/realistic_v51 + base_path: stablediffusionapi/realistic-vision-v51 + # unet_init_weights: #! 注意一下, 完全可以从iv2v的ckpt开始train + # - diffusion_pytorch_model.safetensors # iclight, unet, sf tensor + # - relvid_mm_sd15_fbc.pth + # - iclight_sd15_fbc.safetensors # iclight lora weights + # base_path: stablediffusionapi/realistic-vision-v51 + # vae_init_weights: pretrained_models/instruct_pix2pix/vqvae.ckpt + # text_model_init_weights: pretrained_models/instruct_pix2pix/text.ckpt #! 这两个可以直接设为None, 从from_pretrained中加载 + scale_factor: 0.18215 + guidance_scale: 5 # not used + ddim_sampling_steps: 20 + text_cfg: 7.5 + img_cfg: 1.2 + cond_image_dropout: 0.1 + prompt_type: edit_prompt +unet: + target: modules.video_unet_temporal.unet.UNet3DConditionModel + params: + in_channels: 4 #! change:8->12 iclight 改为12 注意一下... + out_channels: 4 + act_fn: silu + attention_head_dim: 8 + block_out_channels: + - 320 + - 640 + - 1280 + - 1280 + cross_attention_dim: 768 + down_block_types: + - CrossAttnDownBlock3D + - CrossAttnDownBlock3D + - CrossAttnDownBlock3D + - DownBlock3D + up_block_types: + - UpBlock3D + - CrossAttnUpBlock3D + - CrossAttnUpBlock3D + - CrossAttnUpBlock3D + downsample_padding: 1 + layers_per_block: 2 + mid_block_scale_factor: 1 + norm_eps: 1e-05 + norm_num_groups: 32 + sample_size: 64 + use_motion_module: true #!!! 这边test iclight的时候可以不用motion module 即False + motion_module_resolutions: + - 1 + - 2 + - 4 + - 8 + motion_module_mid_block: false + motion_module_decoder_only: false + motion_module_type: Vanilla + motion_module_kwargs: + num_attention_heads: 8 + num_transformer_block: 1 + attention_block_types: + - Temporal_Self + - Temporal_Self + temporal_position_encoding: true + temporal_position_encoding_max_len: 32 + temporal_attention_dim_div: 1 +text_model: + target: modules.openclip.modules.FrozenCLIPEmbedder + params: + freeze: true \ No newline at end of file diff --git a/configs/instruct_v2v_ic_inference.yaml b/configs/instruct_v2v_ic_inference.yaml new file mode 100644 index 0000000000000000000000000000000000000000..81d73ece9c1b3da21a7c737bbb593f2d01fa1143 --- /dev/null +++ b/configs/instruct_v2v_ic_inference.yaml @@ -0,0 +1,79 @@ +diffusion: + target: pl_trainer.instruct_p2p_video.InstructP2PVideoTrainerTemporal + params: + beta_schedule_args: + beta_schedule: scaled_linear + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + clip_sample: false + thresholding: false + prediction_type: epsilon + loss_fn: l2 + optim_args: + lr: 1e-5 + unet_init_weights: #! 注意一下, 完全可以从iv2v的ckpt开始train + - unet/diffusion_pytorch_model.safetensors # iclight, unet, sf tensor + - pretrained_models/Motion_Module/mm_sd_v15.ckpt # motion module, 推测加载的是animatediff的 + - pretrained_models/iclight/iclight_sd15_fbc.safetensors # iclight lora weights + base_path: /mnt/petrelfs/fangye/.cache/huggingface/hub/models--stablediffusionapi--realistic-vision-v51/snapshots/19e3643d7d963c156d01537188ec08f0b79a514a + # vae_init_weights: pretrained_models/instruct_pix2pix/vqvae.ckpt + # text_model_init_weights: pretrained_models/instruct_pix2pix/text.ckpt #! 这两个可以直接设为None, 从from_pretrained中加载 + scale_factor: 0.18215 + guidance_scale: 5 # not used + ddim_sampling_steps: 20 + text_cfg: 7.5 + img_cfg: 1.2 + cond_image_dropout: 0.1 + prompt_type: edit_prompt +unet: + target: modules.video_unet_temporal.unet.UNet3DConditionModel + params: + in_channels: 4 #! change:8->12 iclight 改为12 注意一下... + out_channels: 4 + act_fn: silu + attention_head_dim: 8 + block_out_channels: + - 320 + - 640 + - 1280 + - 1280 + cross_attention_dim: 768 + down_block_types: + - CrossAttnDownBlock3D + - CrossAttnDownBlock3D + - CrossAttnDownBlock3D + - DownBlock3D + up_block_types: + - UpBlock3D + - CrossAttnUpBlock3D + - CrossAttnUpBlock3D + - CrossAttnUpBlock3D + downsample_padding: 1 + layers_per_block: 2 + mid_block_scale_factor: 1 + norm_eps: 1e-05 + norm_num_groups: 32 + sample_size: 64 + use_motion_module: true #!!! 这边test iclight的时候可以不用motion module 即False + motion_module_resolutions: + - 1 + - 2 + - 4 + - 8 + motion_module_mid_block: false + motion_module_decoder_only: false + motion_module_type: Vanilla + motion_module_kwargs: + num_attention_heads: 8 + num_transformer_block: 1 + attention_block_types: + - Temporal_Self + - Temporal_Self + temporal_position_encoding: true + temporal_position_encoding_max_len: 32 + temporal_attention_dim_div: 1 +text_model: + target: modules.openclip.modules.FrozenCLIPEmbedder + params: + freeze: true \ No newline at end of file diff --git a/configs/instruct_v2v_ic_inference_hdr.yaml b/configs/instruct_v2v_ic_inference_hdr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..fd16e3cec09e9a521119335732052737e350f745 --- /dev/null +++ b/configs/instruct_v2v_ic_inference_hdr.yaml @@ -0,0 +1,80 @@ +diffusion: + target: pl_trainer.instruct_p2p_video.InstructP2PVideoTrainerTemporalText + params: + beta_schedule_args: + beta_schedule: scaled_linear + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + clip_sample: false + thresholding: false + prediction_type: epsilon + loss_fn: l2 + optim_args: + lr: 1e-5 + unet_init_weights: #! 注意一下, 完全可以从iv2v的ckpt开始train + - unet/diffusion_pytorch_model.safetensors # iclight, unet, sf tensor + - pretrained_models/Motion_Module/mm_sd_v15.ckpt # motion module, 推测加载的是animatediff的 + - pretrained_models/iclight/iclight_sd15_fc.safetensors # iclight lora weights + base_path: /mnt/petrelfs/fangye/.cache/huggingface/hub/models--stablediffusionapi--realistic-vision-v51/snapshots/19e3643d7d963c156d01537188ec08f0b79a514a + # vae_init_weights: pretrained_models/instruct_pix2pix/vqvae.ckpt + # text_model_init_weights: pretrained_models/instruct_pix2pix/text.ckpt #! 这两个可以直接设为None, 从from_pretrained中加载 + scale_factor: 0.18215 + guidance_scale: 5 # not used + ddim_sampling_steps: 20 + text_cfg: 7.5 + img_cfg: 1.2 + cond_image_dropout: 0.1 + prompt_type: edit_prompt + hdr_train: True +unet: + target: modules.video_unet_temporal.unet.UNet3DConditionModel + params: + in_channels: 4 #! change:8->12 iclight 改为12 注意一下... + out_channels: 4 + act_fn: silu + attention_head_dim: 8 + block_out_channels: + - 320 + - 640 + - 1280 + - 1280 + cross_attention_dim: 768 + down_block_types: + - CrossAttnDownBlock3D + - CrossAttnDownBlock3D + - CrossAttnDownBlock3D + - DownBlock3D + up_block_types: + - UpBlock3D + - CrossAttnUpBlock3D + - CrossAttnUpBlock3D + - CrossAttnUpBlock3D + downsample_padding: 1 + layers_per_block: 2 + mid_block_scale_factor: 1 + norm_eps: 1e-05 + norm_num_groups: 32 + sample_size: 64 + use_motion_module: true #!!! 这边test iclight的时候可以不用motion module 即False + motion_module_resolutions: + - 1 + - 2 + - 4 + - 8 + motion_module_mid_block: false + motion_module_decoder_only: false + motion_module_type: Vanilla + motion_module_kwargs: + num_attention_heads: 8 + num_transformer_block: 1 + attention_block_types: + - Temporal_Self + - Temporal_Self + temporal_position_encoding: true + temporal_position_encoding_max_len: 32 + temporal_attention_dim_div: 1 +text_model: + target: modules.openclip.modules.FrozenCLIPEmbedder + params: + freeze: true \ No newline at end of file diff --git a/configs/instruct_v2v_ic_inference_text.yaml b/configs/instruct_v2v_ic_inference_text.yaml new file mode 100644 index 0000000000000000000000000000000000000000..36b96da5d44d9b857bb5fec90cc520970b3035a4 --- /dev/null +++ b/configs/instruct_v2v_ic_inference_text.yaml @@ -0,0 +1,79 @@ +diffusion: + target: pl_trainer.instruct_p2p_video.InstructP2PVideoTrainerTemporalText + params: + beta_schedule_args: + beta_schedule: scaled_linear + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + clip_sample: false + thresholding: false + prediction_type: epsilon + loss_fn: l2 + optim_args: + lr: 1e-5 + unet_init_weights: #! 注意一下, 完全可以从iv2v的ckpt开始train + - unet/diffusion_pytorch_model.safetensors # iclight, unet, sf tensor + - pretrained_models/Motion_Module/mm_sd_v15.ckpt # motion module, 推测加载的是animatediff的 + - pretrained_models/iclight/iclight_sd15_fc.safetensors # iclight lora weights + base_path: /mnt/petrelfs/fangye/.cache/huggingface/hub/models--stablediffusionapi--realistic-vision-v51/snapshots/19e3643d7d963c156d01537188ec08f0b79a514a + # vae_init_weights: pretrained_models/instruct_pix2pix/vqvae.ckpt + # text_model_init_weights: pretrained_models/instruct_pix2pix/text.ckpt #! 这两个可以直接设为None, 从from_pretrained中加载 + scale_factor: 0.18215 + guidance_scale: 5 # not used + ddim_sampling_steps: 20 + text_cfg: 7.5 + img_cfg: 1.2 + cond_image_dropout: 0.1 + prompt_type: edit_prompt +unet: + target: modules.video_unet_temporal.unet.UNet3DConditionModel + params: + in_channels: 4 #! change:8->12 iclight 改为12 注意一下... + out_channels: 4 + act_fn: silu + attention_head_dim: 8 + block_out_channels: + - 320 + - 640 + - 1280 + - 1280 + cross_attention_dim: 768 + down_block_types: + - CrossAttnDownBlock3D + - CrossAttnDownBlock3D + - CrossAttnDownBlock3D + - DownBlock3D + up_block_types: + - UpBlock3D + - CrossAttnUpBlock3D + - CrossAttnUpBlock3D + - CrossAttnUpBlock3D + downsample_padding: 1 + layers_per_block: 2 + mid_block_scale_factor: 1 + norm_eps: 1e-05 + norm_num_groups: 32 + sample_size: 64 + use_motion_module: true #!!! 这边test iclight的时候可以不用motion module 即False + motion_module_resolutions: + - 1 + - 2 + - 4 + - 8 + motion_module_mid_block: false + motion_module_decoder_only: false + motion_module_type: Vanilla + motion_module_kwargs: + num_attention_heads: 8 + num_transformer_block: 1 + attention_block_types: + - Temporal_Self + - Temporal_Self + temporal_position_encoding: true + temporal_position_encoding_max_len: 32 + temporal_attention_dim_div: 1 +text_model: + target: modules.openclip.modules.FrozenCLIPEmbedder + params: + freeze: true \ No newline at end of file diff --git a/configs/instruct_v2v_ic_pexels.yaml b/configs/instruct_v2v_ic_pexels.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c8277248ff5a411b61845dff0ed030ab46557871 --- /dev/null +++ b/configs/instruct_v2v_ic_pexels.yaml @@ -0,0 +1,133 @@ +expt_dir: experiments +expt_name: instruct_v2v_ic_pexels_text_bgdrop_0.3_trystepckpt +trainer_args: + max_epochs: 1 + accelerator: "gpu" + devices: [0] #! change to get more cards + limit_train_batches: 2048 + limit_val_batches: 1 #! 这边限制了每个epoch只跑多少个batch的validation + # strategy: "ddp" + strategy: "deepspeed_stage_2" + # autotune_only_on_rank_zero: true # 确保只有一个进程执行调优表操作 + accumulate_grad_batches: 1 #! 注意一下这个值 + check_val_every_n_epoch: 1 #! check一下这个值是不是和记录有关。。。 + # precision: 16 # 启用半精度 (FP16) +diffusion: + target: pl_trainer.instruct_p2p_video.InstructP2PVideoTrainerTemporal + params: + beta_schedule_args: + beta_schedule: scaled_linear + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + clip_sample: false + thresholding: false + prediction_type: epsilon + loss_fn: l2 + optim_args: + lr: 1e-5 + unet_init_weights: #! 注意一下, 完全可以从iv2v的ckpt开始train + - unet/diffusion_pytorch_model.safetensors # iclight, unet, sf tensor + - pretrained_models/Motion_Module/mm_sd_v15.ckpt # motion module, 推测加载的是animatediff的 + - pretrained_models/iclight/iclight_sd15_fbc.safetensors # iclight lora weights + base_path: /mnt/petrelfs/fangye/.cache/huggingface/hub/models--stablediffusionapi--realistic-vision-v51/snapshots/19e3643d7d963c156d01537188ec08f0b79a514a + # vae_init_weights: pretrained_models/instruct_pix2pix/vqvae.ckpt + # text_model_init_weights: pretrained_models/instruct_pix2pix/text.ckpt #! 这两个可以直接设为None, 从from_pretrained中加载 + scale_factor: 0.18215 + guidance_scale: 5 # not used + ddim_sampling_steps: 20 + text_cfg: 7.5 + img_cfg: 1.2 + cond_image_dropout: 0.1 + cond_text_dropout: 0.1 + prompt_type: edit_prompt +unet: + target: modules.video_unet_temporal.unet.UNet3DConditionModel + params: + in_channels: 4 #! change:8->12 iclight 改为12 注意一下... + out_channels: 4 + act_fn: silu + attention_head_dim: 8 + block_out_channels: + - 320 + - 640 + - 1280 + - 1280 + cross_attention_dim: 768 + down_block_types: + - CrossAttnDownBlock3D + - CrossAttnDownBlock3D + - CrossAttnDownBlock3D + - DownBlock3D + up_block_types: + - UpBlock3D + - CrossAttnUpBlock3D + - CrossAttnUpBlock3D + - CrossAttnUpBlock3D + downsample_padding: 1 + layers_per_block: 2 + mid_block_scale_factor: 1 + norm_eps: 1e-05 + norm_num_groups: 32 + sample_size: 64 + use_motion_module: true #!!! 这边test iclight的时候可以不用motion module 即False + motion_module_resolutions: + - 1 + - 2 + - 4 + - 8 + motion_module_mid_block: false + motion_module_decoder_only: false + motion_module_type: Vanilla + motion_module_kwargs: + num_attention_heads: 8 + num_transformer_block: 1 + attention_block_types: + - Temporal_Self + - Temporal_Self + temporal_position_encoding: true + temporal_position_encoding_max_len: 32 + temporal_attention_dim_div: 1 +text_model: + target: modules.openclip.modules.FrozenCLIPEmbedder + params: + freeze: true +data: + batch_size: 1 + val_batch_size: 1 + train: + target: dataset.videoP2P.VideoPromptToPromptMotionAugPexels + params: # 注意修改一下training的路径,和相关加载的代码, 比如说没有meta.yaml这些参数怎么搞 + root_dirs: + - /mnt/petrelfs/fangye/test/instruct-video-to-video_1019/data_train_pexels/rmbg_data + num_frames: 16 + zoom_ratio: 0.2 + max_zoom: 1.25 + translation_ratio: 0.7 + translation_range: [0, 0.2] + is_train: True + val: + target: dataset.videoP2P.VideoPromptToPromptMotionAugPexels + params: + root_dirs: + - /mnt/petrelfs/fangye/test/instruct-video-to-video_1019/data_train_pexels/rmbg_data + num_frames: 16 + zoom_ratio: 0.2 + max_zoom: 1.25 + translation_ratio: 0.7 + translation_range: [0, 0.2] +callbacks: + - target: pytorch_lightning.callbacks.ModelCheckpoint + params: + dirpath: "${expt_dir}/${expt_name}" + filename: "{step:06d}" + every_n_train_steps: 10 + save_last: True + # every_n_train_steps: 10 + - target: callbacks.instruct_p2p_video.InstructP2PLogger + params: + max_num_images: 1 + expt_name: instruct_v2v_ic_pexels_text_bgdrop_0.3_trystepckpt + # accumulate_grad_batches: 128 + require_wandb: true + - target: pytorch_lightning.callbacks.DeviceStatsMonitor diff --git a/configs/instruct_v2v_ic_pexels_hdr.yaml b/configs/instruct_v2v_ic_pexels_hdr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..1c41b79795b7ace02e6ed17f9703069d922602ca --- /dev/null +++ b/configs/instruct_v2v_ic_pexels_hdr.yaml @@ -0,0 +1,147 @@ +expt_dir: experiments +expt_name: instruct_v2v_ic_pexels_text_hdr_test_lr0.5_aug_lossc_fix_bs1 #! 注意传入log里面, 不要每次修改 +trainer_args: + max_epochs: 10 + accelerator: "gpu" + devices: [0,1,2,3,4,5,6,7] #! change to get more cards + limit_train_batches: 2048 + limit_val_batches: 3 #! 这边限制了每个epoch只跑多少个batch的validation + # strategy: "ddp" + strategy: "deepspeed_stage_2" + # autotune_only_on_rank_zero: true # 确保只有一个进程执行调优表操作 + accumulate_grad_batches: 128 #! 注意一下这个值 256->128 + check_val_every_n_epoch: 1 #! check一下这个值是不是和记录有关。。。 + # precision: 16 # 启用半精度 (FP16) +diffusion: + target: pl_trainer.instruct_p2p_video.InstructP2PVideoTrainerTemporalText + params: + beta_schedule_args: + beta_schedule: scaled_linear + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + clip_sample: false + thresholding: false + prediction_type: epsilon + loss_fn: l2 + optim_args: + lr: 1e-5 #! 原来是1e-5 + unet_init_weights: #! 注意一下, 完全可以从iv2v的ckpt开始train + - unet/diffusion_pytorch_model.safetensors # iclight, unet, sf tensor + - pretrained_models/Motion_Module/mm_sd_v15.ckpt # motion module, 推测加载的是animatediff的 + - pretrained_models/iclight/iclight_sd15_fc.safetensors # iclight lora weights + base_path: /mnt/petrelfs/fangye/.cache/huggingface/hub/models--stablediffusionapi--realistic-vision-v51/snapshots/19e3643d7d963c156d01537188ec08f0b79a514a + # vae_init_weights: pretrained_models/instruct_pix2pix/vqvae.ckpt + # text_model_init_weights: pretrained_models/instruct_pix2pix/text.ckpt #! 这两个可以直接设为None, 从from_pretrained中加载 + scale_factor: 0.18215 + guidance_scale: 5 # not used + ddim_sampling_steps: 20 + text_cfg: 7.5 + img_cfg: 1.2 + hdr_cfg: 7.5 + cond_image_dropout: 0.1 + cond_text_dropout: 0.1 + cond_hdr_dropout: 0.1 + ic_condition: fg + hdr_train: True + prompt_type: edit_prompt +unet: + target: modules.video_unet_temporal.unet.UNet3DConditionModel + params: + in_channels: 4 #! change:8->12 iclight 改为12 注意一下... + out_channels: 4 + act_fn: silu + attention_head_dim: 8 + block_out_channels: + - 320 + - 640 + - 1280 + - 1280 + cross_attention_dim: 768 + down_block_types: + - CrossAttnDownBlock3D + - CrossAttnDownBlock3D + - CrossAttnDownBlock3D + - DownBlock3D + up_block_types: + - UpBlock3D + - CrossAttnUpBlock3D + - CrossAttnUpBlock3D + - CrossAttnUpBlock3D + downsample_padding: 1 + layers_per_block: 2 + mid_block_scale_factor: 1 + norm_eps: 1e-05 + norm_num_groups: 32 + sample_size: 64 + use_motion_module: true #!!! 这边test iclight的时候可以不用motion module 即False + motion_module_resolutions: + - 1 + - 2 + - 4 + - 8 + motion_module_mid_block: false + motion_module_decoder_only: false + motion_module_type: Vanilla + motion_module_kwargs: + num_attention_heads: 8 + num_transformer_block: 1 + attention_block_types: + - Temporal_Self + - Temporal_Self + temporal_position_encoding: true + temporal_position_encoding_max_len: 32 + temporal_attention_dim_div: 1 +text_model: + target: modules.openclip.modules.FrozenCLIPEmbedder + params: + freeze: true +data: + batch_size: 1 + val_batch_size: 1 + train: + target: dataset.videoP2P.VideoPromptToPromptMotionAugPexelsHDR + params: # 注意修改一下training的路径,和相关加载的代码, 比如说没有meta.yaml这些参数怎么搞 + root_dirs: #! 注意root_dirs已经更改 + # - /mnt/petrelfs/fangye/test/instruct-video-to-video_1019/data_train_pexels/rmbg_data + - /mnt/hwfile/mllm/sunzeyi/iclight_video/rendered_data_rgb_fixlast + hdr_dir: /mnt/hwfile/mllm/sunzeyi/iclight_video/haven_hdr_rgb + num_frames: 16 + zoom_ratio: 0.2 + max_zoom: 1.25 + translation_ratio: 0.7 + translation_range: [0, 0.2] + is_train: True + ic_condition: fg + val: + target: dataset.videoP2P.VideoPromptToPromptMotionAugPexelsHDR + params: + root_dirs: + # - /mnt/petrelfs/fangye/test/instruct-video-to-video_1019/data_train_pexels/rmbg_data + - /mnt/hwfile/mllm/sunzeyi/iclight_video/rendered_data_rgb_fixlast + hdr_dir: /mnt/hwfile/mllm/sunzeyi/iclight_video/haven_hdr_rgb + num_frames: 16 + zoom_ratio: 0.2 + max_zoom: 1.25 + translation_ratio: 0.7 + translation_range: [0, 0.2] + ic_condition: fg +callbacks: + - target: pytorch_lightning.callbacks.ModelCheckpoint + params: + dirpath: "${expt_dir}/${expt_name}" + # filename: "{epoch:04d}" + filename: "{step:06d}" + every_n_train_steps: 1 + save_last: false + # monitor: epoch + # mode: max + # save_top_k: 3 + # save_last: false + - target: callbacks.instruct_p2p_video.InstructP2PLogger + params: + max_num_images: 1 + expt_name: instruct_v2v_ic_pexels_text_hdr_test_lr0.5_aug_lossc_fix_bs1 + # accumulate_grad_batches: 128 + require_wandb: true + - target: pytorch_lightning.callbacks.DeviceStatsMonitor \ No newline at end of file diff --git a/configs/instruct_v2v_ic_pexels_text.yaml b/configs/instruct_v2v_ic_pexels_text.yaml new file mode 100644 index 0000000000000000000000000000000000000000..25d1bf156b7df42e02ee994d48edf20da78bc950 --- /dev/null +++ b/configs/instruct_v2v_ic_pexels_text.yaml @@ -0,0 +1,137 @@ +expt_dir: experiments +expt_name: instruct_v2v_ic_pexels_text_fg #! 注意传入log里面, 不要每次修改 +trainer_args: + max_epochs: 5 + accelerator: "gpu" + devices: [0] #! change to get more cards + limit_train_batches: 2048 + limit_val_batches: 1 #! 这边限制了每个epoch只跑多少个batch的validation + # strategy: "ddp" + strategy: "deepspeed_stage_2" + # autotune_only_on_rank_zero: true # 确保只有一个进程执行调优表操作 + accumulate_grad_batches: 256 #! 注意一下这个值 + check_val_every_n_epoch: 1 #! check一下这个值是不是和记录有关。。。 + # precision: 16 # 启用半精度 (FP16) +diffusion: + target: pl_trainer.instruct_p2p_video.InstructP2PVideoTrainerTemporalText + params: + beta_schedule_args: + beta_schedule: scaled_linear + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + clip_sample: false + thresholding: false + prediction_type: epsilon + loss_fn: l2 + optim_args: + lr: 1e-5 + unet_init_weights: #! 注意一下, 完全可以从iv2v的ckpt开始train + - unet/diffusion_pytorch_model.safetensors # iclight, unet, sf tensor + - pretrained_models/Motion_Module/mm_sd_v15.ckpt # motion module, 推测加载的是animatediff的 + - pretrained_models/iclight/iclight_sd15_fc.safetensors # iclight lora weights + base_path: /mnt/petrelfs/fangye/.cache/huggingface/hub/models--stablediffusionapi--realistic-vision-v51/snapshots/19e3643d7d963c156d01537188ec08f0b79a514a + # vae_init_weights: pretrained_models/instruct_pix2pix/vqvae.ckpt + # text_model_init_weights: pretrained_models/instruct_pix2pix/text.ckpt #! 这两个可以直接设为None, 从from_pretrained中加载 + scale_factor: 0.18215 + guidance_scale: 5 # not used + ddim_sampling_steps: 20 + text_cfg: 7.5 + img_cfg: 1.2 + cond_image_dropout: 0.1 + cond_text_dropout: 0.075 + ic_condition: fg + prompt_type: edit_prompt +unet: + target: modules.video_unet_temporal.unet.UNet3DConditionModel + params: + in_channels: 4 #! change:8->12 iclight 改为12 注意一下... + out_channels: 4 + act_fn: silu + attention_head_dim: 8 + block_out_channels: + - 320 + - 640 + - 1280 + - 1280 + cross_attention_dim: 768 + down_block_types: + - CrossAttnDownBlock3D + - CrossAttnDownBlock3D + - CrossAttnDownBlock3D + - DownBlock3D + up_block_types: + - UpBlock3D + - CrossAttnUpBlock3D + - CrossAttnUpBlock3D + - CrossAttnUpBlock3D + downsample_padding: 1 + layers_per_block: 2 + mid_block_scale_factor: 1 + norm_eps: 1e-05 + norm_num_groups: 32 + sample_size: 64 + use_motion_module: true #!!! 这边test iclight的时候可以不用motion module 即False + motion_module_resolutions: + - 1 + - 2 + - 4 + - 8 + motion_module_mid_block: false + motion_module_decoder_only: false + motion_module_type: Vanilla + motion_module_kwargs: + num_attention_heads: 8 + num_transformer_block: 1 + attention_block_types: + - Temporal_Self + - Temporal_Self + temporal_position_encoding: true + temporal_position_encoding_max_len: 32 + temporal_attention_dim_div: 1 +text_model: + target: modules.openclip.modules.FrozenCLIPEmbedder + params: + freeze: true +data: + batch_size: 1 + val_batch_size: 1 + train: + target: dataset.videoP2P.VideoPromptToPromptMotionAugPexels + params: # 注意修改一下training的路径,和相关加载的代码, 比如说没有meta.yaml这些参数怎么搞 + root_dirs: + - /mnt/petrelfs/fangye/test/instruct-video-to-video_1019/data_train_pexels/rmbg_data + num_frames: 16 + zoom_ratio: 0.2 + max_zoom: 1.25 + translation_ratio: 0.7 + translation_range: [0, 0.2] + is_train: True + ic_condition: fg + val: + target: dataset.videoP2P.VideoPromptToPromptMotionAugPexels + params: + root_dirs: + - /mnt/petrelfs/fangye/test/instruct-video-to-video_1019/data_train_pexels/rmbg_data + num_frames: 16 + zoom_ratio: 0.2 + max_zoom: 1.25 + translation_ratio: 0.7 + translation_range: [0, 0.2] + ic_condition: fg +callbacks: + - target: pytorch_lightning.callbacks.ModelCheckpoint + params: + dirpath: "${expt_dir}/${expt_name}" + filename: "{epoch:04d}" + monitor: epoch + mode: max + save_top_k: 5 + save_last: false + - target: callbacks.instruct_p2p_video.InstructP2PLogger + params: + max_num_images: 1 + expt_name: instruct_v2v_ic_pexels_text_fg + # accumulate_grad_batches: 128 + require_wandb: true + - target: pytorch_lightning.callbacks.DeviceStatsMonitor \ No newline at end of file diff --git a/configs/instruct_v2v_ic_pexels_text_hdr.yaml b/configs/instruct_v2v_ic_pexels_text_hdr.yaml new file mode 100644 index 0000000000000000000000000000000000000000..25d1bf156b7df42e02ee994d48edf20da78bc950 --- /dev/null +++ b/configs/instruct_v2v_ic_pexels_text_hdr.yaml @@ -0,0 +1,137 @@ +expt_dir: experiments +expt_name: instruct_v2v_ic_pexels_text_fg #! 注意传入log里面, 不要每次修改 +trainer_args: + max_epochs: 5 + accelerator: "gpu" + devices: [0] #! change to get more cards + limit_train_batches: 2048 + limit_val_batches: 1 #! 这边限制了每个epoch只跑多少个batch的validation + # strategy: "ddp" + strategy: "deepspeed_stage_2" + # autotune_only_on_rank_zero: true # 确保只有一个进程执行调优表操作 + accumulate_grad_batches: 256 #! 注意一下这个值 + check_val_every_n_epoch: 1 #! check一下这个值是不是和记录有关。。。 + # precision: 16 # 启用半精度 (FP16) +diffusion: + target: pl_trainer.instruct_p2p_video.InstructP2PVideoTrainerTemporalText + params: + beta_schedule_args: + beta_schedule: scaled_linear + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + clip_sample: false + thresholding: false + prediction_type: epsilon + loss_fn: l2 + optim_args: + lr: 1e-5 + unet_init_weights: #! 注意一下, 完全可以从iv2v的ckpt开始train + - unet/diffusion_pytorch_model.safetensors # iclight, unet, sf tensor + - pretrained_models/Motion_Module/mm_sd_v15.ckpt # motion module, 推测加载的是animatediff的 + - pretrained_models/iclight/iclight_sd15_fc.safetensors # iclight lora weights + base_path: /mnt/petrelfs/fangye/.cache/huggingface/hub/models--stablediffusionapi--realistic-vision-v51/snapshots/19e3643d7d963c156d01537188ec08f0b79a514a + # vae_init_weights: pretrained_models/instruct_pix2pix/vqvae.ckpt + # text_model_init_weights: pretrained_models/instruct_pix2pix/text.ckpt #! 这两个可以直接设为None, 从from_pretrained中加载 + scale_factor: 0.18215 + guidance_scale: 5 # not used + ddim_sampling_steps: 20 + text_cfg: 7.5 + img_cfg: 1.2 + cond_image_dropout: 0.1 + cond_text_dropout: 0.075 + ic_condition: fg + prompt_type: edit_prompt +unet: + target: modules.video_unet_temporal.unet.UNet3DConditionModel + params: + in_channels: 4 #! change:8->12 iclight 改为12 注意一下... + out_channels: 4 + act_fn: silu + attention_head_dim: 8 + block_out_channels: + - 320 + - 640 + - 1280 + - 1280 + cross_attention_dim: 768 + down_block_types: + - CrossAttnDownBlock3D + - CrossAttnDownBlock3D + - CrossAttnDownBlock3D + - DownBlock3D + up_block_types: + - UpBlock3D + - CrossAttnUpBlock3D + - CrossAttnUpBlock3D + - CrossAttnUpBlock3D + downsample_padding: 1 + layers_per_block: 2 + mid_block_scale_factor: 1 + norm_eps: 1e-05 + norm_num_groups: 32 + sample_size: 64 + use_motion_module: true #!!! 这边test iclight的时候可以不用motion module 即False + motion_module_resolutions: + - 1 + - 2 + - 4 + - 8 + motion_module_mid_block: false + motion_module_decoder_only: false + motion_module_type: Vanilla + motion_module_kwargs: + num_attention_heads: 8 + num_transformer_block: 1 + attention_block_types: + - Temporal_Self + - Temporal_Self + temporal_position_encoding: true + temporal_position_encoding_max_len: 32 + temporal_attention_dim_div: 1 +text_model: + target: modules.openclip.modules.FrozenCLIPEmbedder + params: + freeze: true +data: + batch_size: 1 + val_batch_size: 1 + train: + target: dataset.videoP2P.VideoPromptToPromptMotionAugPexels + params: # 注意修改一下training的路径,和相关加载的代码, 比如说没有meta.yaml这些参数怎么搞 + root_dirs: + - /mnt/petrelfs/fangye/test/instruct-video-to-video_1019/data_train_pexels/rmbg_data + num_frames: 16 + zoom_ratio: 0.2 + max_zoom: 1.25 + translation_ratio: 0.7 + translation_range: [0, 0.2] + is_train: True + ic_condition: fg + val: + target: dataset.videoP2P.VideoPromptToPromptMotionAugPexels + params: + root_dirs: + - /mnt/petrelfs/fangye/test/instruct-video-to-video_1019/data_train_pexels/rmbg_data + num_frames: 16 + zoom_ratio: 0.2 + max_zoom: 1.25 + translation_ratio: 0.7 + translation_range: [0, 0.2] + ic_condition: fg +callbacks: + - target: pytorch_lightning.callbacks.ModelCheckpoint + params: + dirpath: "${expt_dir}/${expt_name}" + filename: "{epoch:04d}" + monitor: epoch + mode: max + save_top_k: 5 + save_last: false + - target: callbacks.instruct_p2p_video.InstructP2PLogger + params: + max_num_images: 1 + expt_name: instruct_v2v_ic_pexels_text_fg + # accumulate_grad_batches: 128 + require_wandb: true + - target: pytorch_lightning.callbacks.DeviceStatsMonitor \ No newline at end of file diff --git a/configs/instruct_v2v_ic_test.yaml b/configs/instruct_v2v_ic_test.yaml new file mode 100644 index 0000000000000000000000000000000000000000..93adb2f62bc6e8141baa6ecf066f97b03afbcf43 --- /dev/null +++ b/configs/instruct_v2v_ic_test.yaml @@ -0,0 +1,132 @@ +expt_dir: experiments +expt_name: instruct_v2v_ic +trainer_args: + max_epochs: 5 + accelerator: "gpu" + devices: [0,1,2,3] + limit_train_batches: 2048 + limit_val_batches: 1 #! 这边限制了每个epoch只跑多少个batch的validation + # strategy: "ddp" + strategy: "deepspeed_stage_2" + # autotune_only_on_rank_zero: true # 确保只有一个进程执行调优表操作 + accumulate_grad_batches: 32 #! 注意一下这个值 + check_val_every_n_epoch: 1 #! check一下这个值是不是和记录有关。。。 + # precision: 16 # 启用半精度 (FP16) +diffusion: + target: pl_trainer.instruct_p2p_video.InstructP2PVideoTrainerTemporal + params: + beta_schedule_args: + beta_schedule: scaled_linear + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + clip_sample: false + thresholding: false + prediction_type: epsilon + loss_fn: l2 + optim_args: + lr: 1e-5 + unet_init_weights: #! 注意一下, 完全可以从iv2v的ckpt开始train + - unet/diffusion_pytorch_model.safetensors # iclight, unet, sf tensor + - pretrained_models/Motion_Module/mm_sd_v15.ckpt # motion module, 推测加载的是animatediff的 + - pretrained_models/iclight/iclight_sd15_fbc.safetensors # iclight lora weights + base_path: /mnt/petrelfs/fangye/.cache/huggingface/hub/models--stablediffusionapi--realistic-vision-v51/snapshots/19e3643d7d963c156d01537188ec08f0b79a514a + # vae_init_weights: pretrained_models/instruct_pix2pix/vqvae.ckpt + # text_model_init_weights: pretrained_models/instruct_pix2pix/text.ckpt #! 这两个可以直接设为None, 从from_pretrained中加载 + scale_factor: 0.18215 + guidance_scale: 5 # not used + ddim_sampling_steps: 20 + text_cfg: 7.5 + img_cfg: 1.2 + cond_image_dropout: 0.1 + prompt_type: edit_prompt +unet: + target: modules.video_unet_temporal.unet.UNet3DConditionModel + params: + in_channels: 4 #! change:8->12 iclight 改为12 注意一下... + out_channels: 4 + act_fn: silu + attention_head_dim: 8 + block_out_channels: + - 320 + - 640 + - 1280 + - 1280 + cross_attention_dim: 768 + down_block_types: + - CrossAttnDownBlock3D + - CrossAttnDownBlock3D + - CrossAttnDownBlock3D + - DownBlock3D + up_block_types: + - UpBlock3D + - CrossAttnUpBlock3D + - CrossAttnUpBlock3D + - CrossAttnUpBlock3D + downsample_padding: 1 + layers_per_block: 2 + mid_block_scale_factor: 1 + norm_eps: 1e-05 + norm_num_groups: 32 + sample_size: 64 + use_motion_module: true #!!! 这边test iclight的时候可以不用motion module 即False + motion_module_resolutions: + - 1 + - 2 + - 4 + - 8 + motion_module_mid_block: false + motion_module_decoder_only: false + motion_module_type: Vanilla + motion_module_kwargs: + num_attention_heads: 8 + num_transformer_block: 1 + attention_block_types: + - Temporal_Self + - Temporal_Self + temporal_position_encoding: true + temporal_position_encoding_max_len: 32 + temporal_attention_dim_div: 1 +text_model: + target: modules.openclip.modules.FrozenCLIPEmbedder + params: + freeze: true +data: + batch_size: 1 + val_batch_size: 1 + train: + target: dataset.videoP2P.VideoPromptToPromptMotionAug + params: # 注意修改一下training的路径,和相关加载的代码, 比如说没有meta.yaml这些参数怎么搞 + root_dirs: + - /mnt/petrelfs/fangye/test/instruct-video-to-video_1019/data_train_v2 + num_frames: 16 + zoom_ratio: 0.2 + max_zoom: 1.25 + translation_ratio: 0.7 + translation_range: [0, 0.2] + is_train: True + val: + target: dataset.videoP2P.VideoPromptToPromptMotionAug + params: + root_dirs: + - data_train + num_frames: 16 + zoom_ratio: 0.2 + max_zoom: 1.25 + translation_ratio: 0.7 + translation_range: [0, 0.2] +callbacks: + - target: pytorch_lightning.callbacks.ModelCheckpoint + params: + dirpath: "${expt_dir}/${expt_name}" + filename: "{epoch:04d}" + monitor: epoch + mode: max + save_top_k: 5 + save_last: true + - target: callbacks.instruct_p2p_video.InstructP2PLogger + params: + max_num_images: 1 + # accumulate_grad_batches: 128 + require_wandb: true + - target: pytorch_lightning.callbacks.DeviceStatsMonitor \ No newline at end of file diff --git a/configs/instruct_v2v_inference.yaml b/configs/instruct_v2v_inference.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e677d8b54c800264def94dd7f0172838a6ed113c --- /dev/null +++ b/configs/instruct_v2v_inference.yaml @@ -0,0 +1,98 @@ +diffusion: + target: pl_trainer.instruct_p2p_video.InstructP2PVideoTrainerTemporal + params: + beta_schedule_args: + beta_schedule: scaled_linear + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + clip_sample: false + thresholding: false + prediction_type: epsilon + loss_fn: l2 + unet_init_weights: + - pretrained_models/instruct_pix2pix/diffusion_pytorch_model.bin + - pretrained_models/Motion_Module/mm_sd_v15.ckpt + vae_init_weights: pretrained_models/instruct_pix2pix/vqvae.ckpt + text_model_init_weights: pretrained_models/instruct_pix2pix/text.ckpt + optim_args: + lr: 1e-5 + scale_factor: 0.18215 + guidance_scale: 5 # not used + ddim_sampling_steps: 20 + text_cfg: 7.5 + img_cfg: 1.2 + cond_image_dropout: 0.1 + prompt_type: edit_prompt +unet: + target: modules.video_unet_temporal.unet.UNet3DConditionModel + params: + in_channels: 8 + out_channels: 4 + act_fn: silu + attention_head_dim: 8 + block_out_channels: + - 320 + - 640 + - 1280 + - 1280 + cross_attention_dim: 768 + down_block_types: + - CrossAttnDownBlock3D + - CrossAttnDownBlock3D + - CrossAttnDownBlock3D + - DownBlock3D + up_block_types: + - UpBlock3D + - CrossAttnUpBlock3D + - CrossAttnUpBlock3D + - CrossAttnUpBlock3D + downsample_padding: 1 + layers_per_block: 2 + mid_block_scale_factor: 1 + norm_eps: 1e-05 + norm_num_groups: 32 + sample_size: 64 + use_motion_module: true + motion_module_resolutions: + - 1 + - 2 + - 4 + - 8 + motion_module_mid_block: false + motion_module_decoder_only: false + motion_module_type: Vanilla + motion_module_kwargs: + num_attention_heads: 8 + num_transformer_block: 1 + attention_block_types: + - Temporal_Self + - Temporal_Self + temporal_position_encoding: true + temporal_position_encoding_max_len: 32 + temporal_attention_dim_div: 1 +vae: + target: modules.kl_autoencoder.autoencoder.AutoencoderKL + params: + embed_dim: 4 + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity +text_model: + target: modules.openclip.modules.FrozenCLIPEmbedder + params: + freeze: true \ No newline at end of file diff --git a/configs/instruct_v2v_ori.yaml b/configs/instruct_v2v_ori.yaml new file mode 100644 index 0000000000000000000000000000000000000000..ac35a082a4f680f21628f8e6a899d0ddce4af3da --- /dev/null +++ b/configs/instruct_v2v_ori.yaml @@ -0,0 +1,147 @@ +expt_dir: experiments +expt_name: instruct_v2v +trainer_args: + max_epochs: 10 + accelerator: "gpu" + devices: [0] + limit_train_batches: 2048 + limit_val_batches: 1 + # strategy: "ddp" + strategy: "deepspeed_stage_2" + accumulate_grad_batches: 256 + check_val_every_n_epoch: 5 +diffusion: + target: pl_trainer.instruct_p2p_video.InstructP2PVideoTrainerTemporal + params: + beta_schedule_args: + beta_schedule: scaled_linear + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + clip_sample: false + thresholding: false + prediction_type: epsilon + loss_fn: l2 + optim_args: + lr: 1e-5 + unet_init_weights: + - pretrained_models/instruct_pix2pix/diffusion_pytorch_model.bin + - pretrained_models/Motion_Module/mm_sd_v15.ckpt + vae_init_weights: pretrained_models/instruct_pix2pix/vqvae.ckpt + text_model_init_weights: pretrained_models/instruct_pix2pix/text.ckpt + scale_factor: 0.18215 + guidance_scale: 5 # not used + ddim_sampling_steps: 20 + text_cfg: 7.5 + img_cfg: 1.2 + cond_image_dropout: 0.1 + prompt_type: edit_prompt +unet: + target: modules.video_unet_temporal.unet.UNet3DConditionModel + params: + in_channels: 8 + out_channels: 4 + act_fn: silu + attention_head_dim: 8 + block_out_channels: + - 320 + - 640 + - 1280 + - 1280 + cross_attention_dim: 768 + down_block_types: + - CrossAttnDownBlock3D + - CrossAttnDownBlock3D + - CrossAttnDownBlock3D + - DownBlock3D + up_block_types: + - UpBlock3D + - CrossAttnUpBlock3D + - CrossAttnUpBlock3D + - CrossAttnUpBlock3D + downsample_padding: 1 + layers_per_block: 2 + mid_block_scale_factor: 1 + norm_eps: 1e-05 + norm_num_groups: 32 + sample_size: 64 + use_motion_module: true + motion_module_resolutions: + - 1 + - 2 + - 4 + - 8 + motion_module_mid_block: false + motion_module_decoder_only: false + motion_module_type: Vanilla + motion_module_kwargs: + num_attention_heads: 8 + num_transformer_block: 1 + attention_block_types: + - Temporal_Self + - Temporal_Self + temporal_position_encoding: true + temporal_position_encoding_max_len: 32 + temporal_attention_dim_div: 1 +vae: + target: modules.kl_autoencoder.autoencoder.AutoencoderKL + params: + embed_dim: 4 + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity +text_model: + target: modules.openclip.modules.FrozenCLIPEmbedder + params: + freeze: true +data: + batch_size: 1 + val_batch_size: 1 + train: + target: dataset.videoP2P.VideoPromptToPromptMotionAug + params: + root_dirs: + - video_ptp/raw_generated + - video_ptp/raw_generated_webvid + num_frames: 16 + zoom_ratio: 0.2 + max_zoom: 1.25 + translation_ratio: 0.7 + translation_range: [0, 0.2] + val: + target: dataset.videoP2P.VideoPromptToPromptMotionAug + params: + root_dirs: + - video_ptp/raw_generated + num_frames: 16 + zoom_ratio: 0.2 + max_zoom: 1.25 + translation_ratio: 0.7 + translation_range: [0, 0.2] +callbacks: + - target: pytorch_lightning.callbacks.ModelCheckpoint + params: + dirpath: "${expt_dir}/${expt_name}" + filename: "{epoch:04d}" + monitor: epoch + mode: max + save_top_k: 5 + save_last: true + - target: callbacks.instruct_p2p_video.InstructP2PLogger + params: + max_num_images: 1 + require_wandb: true \ No newline at end of file diff --git a/configs/test_textmodel.yaml b/configs/test_textmodel.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3e608dc1fb2ba3d22da909fab685fa13ae35a1ff --- /dev/null +++ b/configs/test_textmodel.yaml @@ -0,0 +1,7 @@ +diffusion: + params: + base_path: /home/fy/.cache/huggingface/hub/models--stablediffusionapi--realistic-vision-v51/snapshots/19e3643d7d963c156d01537188ec08f0b79a514a +text_model: + target: modules.openclip.modules.FrozenCLIPEmbedder + params: + freeze: true \ No newline at end of file diff --git a/configs/test_vae.yaml b/configs/test_vae.yaml new file mode 100644 index 0000000000000000000000000000000000000000..3842df52764a4e180c009ab3a9fa23c6e7009fcf --- /dev/null +++ b/configs/test_vae.yaml @@ -0,0 +1,21 @@ +vae: + target: modules.kl_autoencoder.autoencoder.AutoencoderKL + params: + embed_dim: 4 + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 #先暂时256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity \ No newline at end of file diff --git a/configs/test_vae_ori.yaml b/configs/test_vae_ori.yaml new file mode 100644 index 0000000000000000000000000000000000000000..282fe342b2155c90e358e12f98aea1ab8d109230 --- /dev/null +++ b/configs/test_vae_ori.yaml @@ -0,0 +1,28 @@ +diffusion: + params: + base_path: /home/fy/.cache/huggingface/hub/models--stablediffusionapi--realistic-vision-v51/snapshots/19e3643d7d963c156d01537188ec08f0b79a514a +text_model: + target: modules.openclip.modules.FrozenCLIPEmbedder + params: + freeze: true +vae: + target: modules.kl_autoencoder.autoencoder.AutoencoderKL + params: + embed_dim: 4 + ddconfig: + double_z: true + z_channels: 4 + resolution: 256 #先暂时256 + in_channels: 3 + out_ch: 3 + ch: 128 + ch_mult: + - 1 + - 2 + - 4 + - 4 + num_res_blocks: 2 + attn_resolutions: [] + dropout: 0.0 + lossconfig: + target: torch.nn.Identity \ No newline at end of file diff --git a/configs/tmp_ic.yaml b/configs/tmp_ic.yaml new file mode 100644 index 0000000000000000000000000000000000000000..199485d033ab4ef687b8111465f98b4eb8f5b04a --- /dev/null +++ b/configs/tmp_ic.yaml @@ -0,0 +1,130 @@ +expt_dir: experiments +expt_name: instruct_v2v_ic +trainer_args: + max_epochs: 10 + accelerator: "gpu" + devices: [0] + limit_train_batches: 2048 + limit_val_batches: 5 #! 这边限制了每个epoch只跑多少个batch的validation + # strategy: "ddp" + strategy: "deepspeed_stage_2" + accumulate_grad_batches: 128 #! 注意一下这个值 + check_val_every_n_epoch: 1 #! check一下这个值是不是和记录有关。。。 +diffusion: + target: pl_trainer.instruct_p2p_video.InstructP2PVideoTrainerTemporal + params: + beta_schedule_args: + beta_schedule: scaled_linear + num_train_timesteps: 1000 + beta_start: 0.00085 + beta_end: 0.012 + clip_sample: false + thresholding: false + prediction_type: epsilon + loss_fn: l2 + optim_args: + lr: 1e-5 + unet_init_weights: #! 注意一下, 完全可以从iv2v的ckpt开始train + - unet/diffusion_pytorch_model.safetensors # iclight, unet, sf tensor + - pretrained_models/Motion_Module/mm_sd_v15.ckpt # motion module, 推测加载的是animatediff的 + - /mnt/petrelfs/fangye/IC-Light/models/iclight_sd15_fbc.safetensors # iclight lora weights + base_path: /mnt/petrelfs/fangye/.cache/huggingface/hub/models--stablediffusionapi--realistic-vision-v51/snapshots/19e3643d7d963c156d01537188ec08f0b79a514a + # vae_init_weights: pretrained_models/instruct_pix2pix/vqvae.ckpt + # text_model_init_weights: pretrained_models/instruct_pix2pix/text.ckpt #! 这两个可以直接设为None, 从from_pretrained中加载 + scale_factor: 0.18215 + guidance_scale: 5 # not used + ddim_sampling_steps: 20 + text_cfg: 7.5 + img_cfg: 1.2 + cond_image_dropout: 0.1 + prompt_type: edit_prompt +unet: + target: modules.video_unet_temporal.unet.UNet3DConditionModel + params: + in_channels: 4 #! change:8->12 iclight 改为12 注意一下... + out_channels: 4 + act_fn: silu + attention_head_dim: 8 + block_out_channels: + - 320 + - 640 + - 1280 + - 1280 + cross_attention_dim: 768 + down_block_types: + - CrossAttnDownBlock3D + - CrossAttnDownBlock3D + - CrossAttnDownBlock3D + - DownBlock3D + up_block_types: + - UpBlock3D + - CrossAttnUpBlock3D + - CrossAttnUpBlock3D + - CrossAttnUpBlock3D + downsample_padding: 1 + layers_per_block: 2 + mid_block_scale_factor: 1 + norm_eps: 1e-05 + norm_num_groups: 32 + sample_size: 64 + use_motion_module: true #! 这边test iclight的时候可以不用motion module 即False + motion_module_resolutions: + - 1 + - 2 + - 4 + - 8 + motion_module_mid_block: false + motion_module_decoder_only: false + motion_module_type: Vanilla + motion_module_kwargs: + num_attention_heads: 8 + num_transformer_block: 1 + attention_block_types: + - Temporal_Self + - Temporal_Self + temporal_position_encoding: true + temporal_position_encoding_max_len: 32 + temporal_attention_dim_div: 1 +text_model: + target: modules.openclip.modules.FrozenCLIPEmbedder + params: + freeze: true +data: + batch_size: 1 + val_batch_size: 1 + train: + target: dataset.videoP2P.VideoPromptToPromptMotionAug + params: #注意修改一下training的路径,和相关加载的代码, 比如说没有meta.yaml这些参数怎么搞 + root_dirs: + - /home/fy/Code/instruct-video-to-video/data_train/Girl + num_frames: 16 + zoom_ratio: 0.2 + max_zoom: 1.25 + translation_ratio: 0.7 + translation_range: [0, 0.2] + is_train: True + val: + target: dataset.videoP2P.VideoPromptToPromptMotionAug + params: + root_dirs: + - /home/fy/Code/instruct-video-to-video/data_train/Girl + num_frames: 16 + zoom_ratio: 0.2 + max_zoom: 1.25 + translation_ratio: 0.7 + translation_range: [0, 0.2] +callbacks: + - target: pytorch_lightning.callbacks.ModelCheckpoint + params: + dirpath: "${expt_dir}/${expt_name}" + filename: "{epoch:04d}" + monitor: epoch + mode: max + save_top_k: 5 + save_last: true + - target: callbacks.instruct_p2p_video.InstructP2PLogger + params: + max_num_images: 1 + # accumulate_grad_batches: 128 + require_wandb: true + - target: pytorch_lightning.callbacks.DeviceStatsMonitor \ No newline at end of file diff --git a/db_examples.py b/db_examples.py new file mode 100644 index 0000000000000000000000000000000000000000..3e4fc10534c24c46dc6f05e79252cb80844d4ad0 --- /dev/null +++ b/db_examples.py @@ -0,0 +1,133 @@ + +bg_samples = [ + 'demo/clean_bg_extracted/22/frames/0000.png', + 'demo/clean_bg_extracted/23/frames/0000.png', + 'demo/clean_bg_extracted/27/frames/0000.png', + 'demo/clean_bg_extracted/33/frames/0000.png', + 'demo/clean_bg_extracted/47/frames/0000.png', + 'demo/clean_bg_extracted/39/frames/0000.png', + 'demo/clean_bg_extracted/59/frames/0000.png', + 'demo/clean_bg_extracted/55/frames/0000.png', + 'demo/clean_bg_extracted/58/frames/0000.png', + 'demo/clean_bg_extracted/57/frames/0000.png', #42 + 'demo/clean_bg_extracted/8/frames/0000.png', + 'demo/clean_bg_extracted/9/frames/0000.png', + 'demo/clean_bg_extracted/10/frames/0000.png', + 'demo/clean_bg_extracted/14/frames/0000.png', + 'demo/clean_bg_extracted/62/frames/0000.png' +] # 准备大概 15 个 background视频 + + +background_conditioned_examples = [ + [ + "demo/clean_fg_extracted/14/cropped_video.mp4", + "demo/clean_bg_extracted/22/cropped_video.mp4", + "beautiful woman, cinematic lighting", + "Use Background Video", + 512, + 768, + 12345, + "static_fg_sync_bg_visualization_fy/14_22_100fps.mp4", + ], + [ + "demo/clean_fg_extracted/14/cropped_video.mp4", + "demo/clean_bg_extracted/55/cropped_video.mp4", + "beautiful woman, cinematic lighting", + "Use Background Video", + 512, + 768, + 12345, + "static_fg_sync_bg_visualization_fy/14_55_100fps.mp4", + ], + [ + "demo/clean_fg_extracted/15/cropped_video.mp4", + "demo/clean_bg_extracted/27/cropped_video.mp4", + "beautiful woman, cinematic lighting", + "Use Background Video", + 512, + 768, + 12345, + "static_fg_sync_bg_visualization_fy/15_27_100fps.mp4", + ], + [ + "demo/clean_fg_extracted/18/cropped_video.mp4", + "demo/clean_bg_extracted/23/cropped_video.mp4", + "beautiful woman, cinematic lighting", + "Use Background Video", + 512, + 768, + 12345, + "static_fg_sync_bg_visualization_fy/18_23_100fps.mp4", + ], + # [ + # "demo/clean_fg_extracted/18/cropped_video.mp4", + # "demo/clean_bg_extracted/33/cropped_video.mp4", + # "beautiful woman, cinematic lighting", + # "Use Background Video", + # 512, + # 768, + # 12345, + # "static_fg_sync_bg_visualization_fy/18_33_100fps.mp4", + # ], + [ + "demo/clean_fg_extracted/22/cropped_video.mp4", + "demo/clean_bg_extracted/39/cropped_video.mp4", + "beautiful woman, cinematic lighting", + "Use Background Video", + 512, + 768, + 12345, + "static_fg_sync_bg_visualization_fy/22_39_100fps.mp4", + ], + # [ + # "demo/clean_fg_extracted/22/cropped_video.mp4", + # "demo/clean_bg_extracted/59/cropped_video.mp4", + # "beautiful woman, cinematic lighting", + # "Use Background Video", + # 512, + # 768, + # 12345, + # "static_fg_sync_bg_visualization_fy/22_59_100fps.mp4", + # ], + [ + "demo/clean_fg_extracted/9/cropped_video.mp4", + "demo/clean_bg_extracted/8/cropped_video.mp4", + "beautiful woman, cinematic lighting", + "Use Background Video", + 512, + 768, + 12345, + "static_fg_sync_bg_visualization_fy/9_8_100fps.mp4", + ], + [ + "demo/clean_fg_extracted/9/cropped_video.mp4", + "demo/clean_bg_extracted/9/cropped_video.mp4", + "beautiful woman, cinematic lighting", + "Use Background Video", + 512, + 768, + 12345, + "static_fg_sync_bg_visualization_fy/9_9_100fps.mp4", + ], + [ + "demo/clean_fg_extracted/9/cropped_video.mp4", + "demo/clean_bg_extracted/10/cropped_video.mp4", + "beautiful woman, cinematic lighting", + "Use Background Video", + 512, + 768, + 12345, + "static_fg_sync_bg_visualization_fy/9_10_100fps.mp4", + ], + # [ + # "demo/clean_fg_extracted/9/cropped_video.mp4", + # "demo/clean_bg_extracted/14/cropped_video.mp4", + # "beautiful woman, cinematic lighting", + # "Use Background Video", + # 512, + # 768, + # 12345, + # "static_fg_sync_bg_visualization_fy/9_14_100fps.mp4", + # ], + +] diff --git a/demo/clean_bg_extracted/10/cropped_video.mp4 b/demo/clean_bg_extracted/10/cropped_video.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..bcf273717ef3746e3d765bbce4ca326080365a36 Binary files /dev/null and b/demo/clean_bg_extracted/10/cropped_video.mp4 differ diff --git a/demo/clean_bg_extracted/10/frames/0000.png b/demo/clean_bg_extracted/10/frames/0000.png new file mode 100644 index 0000000000000000000000000000000000000000..cc82d4dc31f2d55b7dba46d0324b511dfc3d7801 Binary files /dev/null and b/demo/clean_bg_extracted/10/frames/0000.png differ diff --git a/demo/clean_bg_extracted/14/cropped_video.mp4 b/demo/clean_bg_extracted/14/cropped_video.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..451ebd3c55894329209a70f512dc70604588a5a4 Binary files /dev/null and b/demo/clean_bg_extracted/14/cropped_video.mp4 differ diff --git a/demo/clean_bg_extracted/14/frames/0000.png b/demo/clean_bg_extracted/14/frames/0000.png new file mode 100644 index 0000000000000000000000000000000000000000..c9e78355e592e7302e1d17a98c56f0f458a77a11 Binary files /dev/null and b/demo/clean_bg_extracted/14/frames/0000.png differ diff --git a/demo/clean_bg_extracted/22/cropped_video.mp4 b/demo/clean_bg_extracted/22/cropped_video.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..890156c9a411aa9dd58a77a18e5f44cea5dc1bc4 Binary files /dev/null and b/demo/clean_bg_extracted/22/cropped_video.mp4 differ diff --git a/demo/clean_bg_extracted/22/frames/0000.png b/demo/clean_bg_extracted/22/frames/0000.png new file mode 100644 index 0000000000000000000000000000000000000000..3b4309c022f142121878fa680ff78ae878ed6d03 Binary files /dev/null and b/demo/clean_bg_extracted/22/frames/0000.png differ diff --git a/demo/clean_bg_extracted/23/cropped_video.mp4 b/demo/clean_bg_extracted/23/cropped_video.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..f9fc64c08ed134830410617f2f4acd7bf4d779f6 Binary files /dev/null and b/demo/clean_bg_extracted/23/cropped_video.mp4 differ diff --git a/demo/clean_bg_extracted/23/frames/0000.png b/demo/clean_bg_extracted/23/frames/0000.png new file mode 100644 index 0000000000000000000000000000000000000000..9782f01b6d1dad0d0fde9aefe3ef1037b5490c86 Binary files /dev/null and b/demo/clean_bg_extracted/23/frames/0000.png differ diff --git a/demo/clean_bg_extracted/27/cropped_video.mp4 b/demo/clean_bg_extracted/27/cropped_video.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..ed8ff45e299e1eec5f0602d0e72d7f73090b4aea Binary files /dev/null and b/demo/clean_bg_extracted/27/cropped_video.mp4 differ diff --git a/demo/clean_bg_extracted/27/frames/0000.png b/demo/clean_bg_extracted/27/frames/0000.png new file mode 100644 index 0000000000000000000000000000000000000000..12313cf5ea53c028a275ce9046fa0ae5cde66066 Binary files /dev/null and b/demo/clean_bg_extracted/27/frames/0000.png differ diff --git a/demo/clean_bg_extracted/33/cropped_video.mp4 b/demo/clean_bg_extracted/33/cropped_video.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..edeb664704f8442881a8ad1cc163e6553ee9e50c Binary files /dev/null and b/demo/clean_bg_extracted/33/cropped_video.mp4 differ diff --git a/demo/clean_bg_extracted/33/frames/0000.png b/demo/clean_bg_extracted/33/frames/0000.png new file mode 100644 index 0000000000000000000000000000000000000000..7503ea1d18aed3622071782e95bfba6fbcecb6f6 Binary files /dev/null and b/demo/clean_bg_extracted/33/frames/0000.png differ diff --git a/demo/clean_bg_extracted/39/cropped_video.mp4 b/demo/clean_bg_extracted/39/cropped_video.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..57490e004f7e6f2affb56a28f0fa9bcca0c198af Binary files /dev/null and b/demo/clean_bg_extracted/39/cropped_video.mp4 differ diff --git a/demo/clean_bg_extracted/39/frames/0000.png b/demo/clean_bg_extracted/39/frames/0000.png new file mode 100644 index 0000000000000000000000000000000000000000..b068e136c78092c32225d01a355f559d6de4782b Binary files /dev/null and b/demo/clean_bg_extracted/39/frames/0000.png differ diff --git a/demo/clean_bg_extracted/47/frames/0000.png b/demo/clean_bg_extracted/47/frames/0000.png new file mode 100644 index 0000000000000000000000000000000000000000..179bd23120346b29caf0c248b9421d91a4b8c760 Binary files /dev/null and b/demo/clean_bg_extracted/47/frames/0000.png differ diff --git a/demo/clean_bg_extracted/55/cropped_video.mp4 b/demo/clean_bg_extracted/55/cropped_video.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..0f6fba93ae26931ac5f85e922abc0e5efe3ba69c Binary files /dev/null and b/demo/clean_bg_extracted/55/cropped_video.mp4 differ diff --git a/demo/clean_bg_extracted/55/frames/0000.png b/demo/clean_bg_extracted/55/frames/0000.png new file mode 100644 index 0000000000000000000000000000000000000000..0804d3bf2ec545f1d98a650cfcc9bec24f78d1aa Binary files /dev/null and b/demo/clean_bg_extracted/55/frames/0000.png differ diff --git a/demo/clean_bg_extracted/57/frames/0000.png b/demo/clean_bg_extracted/57/frames/0000.png new file mode 100644 index 0000000000000000000000000000000000000000..d3aceb2aeecb4e34c9c97ee0aaecf9a39ac443b2 Binary files /dev/null and b/demo/clean_bg_extracted/57/frames/0000.png differ diff --git a/demo/clean_bg_extracted/58/frames/0000.png b/demo/clean_bg_extracted/58/frames/0000.png new file mode 100644 index 0000000000000000000000000000000000000000..b3f1f70106db16aa949459e105486b7ce93bee0a Binary files /dev/null and b/demo/clean_bg_extracted/58/frames/0000.png differ diff --git a/demo/clean_bg_extracted/59/cropped_video.mp4 b/demo/clean_bg_extracted/59/cropped_video.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..d12d7a8a8a8cfe91cf7e3d3055400d66cf604e4d Binary files /dev/null and b/demo/clean_bg_extracted/59/cropped_video.mp4 differ diff --git a/demo/clean_bg_extracted/59/frames/0000.png b/demo/clean_bg_extracted/59/frames/0000.png new file mode 100644 index 0000000000000000000000000000000000000000..0c4182febbf71a50b8bf4e66430082d37efbb66b Binary files /dev/null and b/demo/clean_bg_extracted/59/frames/0000.png differ diff --git a/demo/clean_bg_extracted/62/frames/0000.png b/demo/clean_bg_extracted/62/frames/0000.png new file mode 100644 index 0000000000000000000000000000000000000000..02947c84cdb907c49eb35ddf9ccb45404a248d21 Binary files /dev/null and b/demo/clean_bg_extracted/62/frames/0000.png differ diff --git a/demo/clean_bg_extracted/8/cropped_video.mp4 b/demo/clean_bg_extracted/8/cropped_video.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..6f07f6534836a6104840fedd9d1a6313bd6f5e4d Binary files /dev/null and b/demo/clean_bg_extracted/8/cropped_video.mp4 differ diff --git a/demo/clean_bg_extracted/8/frames/0000.png b/demo/clean_bg_extracted/8/frames/0000.png new file mode 100644 index 0000000000000000000000000000000000000000..144e155bd11d23cbef05d131c7da3ba77f357c56 Binary files /dev/null and b/demo/clean_bg_extracted/8/frames/0000.png differ diff --git a/demo/clean_bg_extracted/9/cropped_video.mp4 b/demo/clean_bg_extracted/9/cropped_video.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..2c0df6e93110bce59f90e3e716cc469f5ae9891a Binary files /dev/null and b/demo/clean_bg_extracted/9/cropped_video.mp4 differ diff --git a/demo/clean_bg_extracted/9/frames/0000.png b/demo/clean_bg_extracted/9/frames/0000.png new file mode 100644 index 0000000000000000000000000000000000000000..67f0ddfc65f3be34c880c1ffefa752808f3dc3e3 Binary files /dev/null and b/demo/clean_bg_extracted/9/frames/0000.png differ diff --git a/demo/clean_fg_extracted/14/cropped_video.mp4 b/demo/clean_fg_extracted/14/cropped_video.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..84e13e39814692bd17da75699c2681f645b5780c Binary files /dev/null and b/demo/clean_fg_extracted/14/cropped_video.mp4 differ diff --git a/demo/clean_fg_extracted/15/cropped_video.mp4 b/demo/clean_fg_extracted/15/cropped_video.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..2432826c84a1ec27ba1230cc67c0d0d7a20abbd7 Binary files /dev/null and b/demo/clean_fg_extracted/15/cropped_video.mp4 differ diff --git a/demo/clean_fg_extracted/18/cropped_video.mp4 b/demo/clean_fg_extracted/18/cropped_video.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..6ce61744725b04adc018700cbe1b0d058800817d Binary files /dev/null and b/demo/clean_fg_extracted/18/cropped_video.mp4 differ diff --git a/demo/clean_fg_extracted/22/cropped_video.mp4 b/demo/clean_fg_extracted/22/cropped_video.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..9b3321d61b3aed496e6382810c907a5f2e3a7b49 Binary files /dev/null and b/demo/clean_fg_extracted/22/cropped_video.mp4 differ diff --git a/demo/clean_fg_extracted/9/cropped_video.mp4 b/demo/clean_fg_extracted/9/cropped_video.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..51d5f12ac9b4b74438dc883e533eea5b28cdc606 Binary files /dev/null and b/demo/clean_fg_extracted/9/cropped_video.mp4 differ diff --git a/demo_utils1.py b/demo_utils1.py new file mode 100644 index 0000000000000000000000000000000000000000..8ca623b8a1173902c5f1dee78d53c106f10a1971 --- /dev/null +++ b/demo_utils1.py @@ -0,0 +1,9 @@ +import os + +# 更改全局临时目录 +new_tmp_dir = "./demo/gradio_bg" +os.makedirs(new_tmp_dir, exist_ok=True) + +os.environ['GRADIO_TEMP_DIR'] = new_tmp_dir + + diff --git a/filtered_params.txt b/filtered_params.txt new file mode 100644 index 0000000000000000000000000000000000000000..3edbe36598885dd0e8e320ab28a5344601d6f533 --- /dev/null +++ b/filtered_params.txt @@ -0,0 +1,560 @@ +unet.down_blocks.0.motion_modules.0.temporal_transformer.norm.weight: torch.Size([320]) +unet.down_blocks.0.motion_modules.0.temporal_transformer.norm.bias: torch.Size([320]) +unet.down_blocks.0.motion_modules.0.temporal_transformer.proj_in.weight: torch.Size([320, 320]) +unet.down_blocks.0.motion_modules.0.temporal_transformer.proj_in.bias: torch.Size([320]) +unet.down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_q.weight: torch.Size([320, 320]) +unet.down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_k.weight: torch.Size([320, 320]) +unet.down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_v.weight: torch.Size([320, 320]) +unet.down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.weight: torch.Size([320, 320]) +unet.down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.bias: torch.Size([320]) +unet.down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe: torch.Size([1, 32, 320]) +unet.down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_q.weight: torch.Size([320, 320]) +unet.down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_k.weight: torch.Size([320, 320]) +unet.down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_v.weight: torch.Size([320, 320]) +unet.down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.weight: torch.Size([320, 320]) +unet.down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.bias: torch.Size([320]) +unet.down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe: torch.Size([1, 32, 320]) +unet.down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.norms.0.weight: torch.Size([320]) +unet.down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.norms.0.bias: torch.Size([320]) +unet.down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.norms.1.weight: torch.Size([320]) +unet.down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.norms.1.bias: torch.Size([320]) +unet.down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.ff.net.0.proj.weight: torch.Size([2560, 320]) +unet.down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.ff.net.0.proj.bias: torch.Size([2560]) +unet.down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.ff.net.2.weight: torch.Size([320, 1280]) +unet.down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.ff.net.2.bias: torch.Size([320]) +unet.down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.ff_norm.weight: torch.Size([320]) +unet.down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.ff_norm.bias: torch.Size([320]) +unet.down_blocks.0.motion_modules.0.temporal_transformer.proj_out.weight: torch.Size([320, 320]) +unet.down_blocks.0.motion_modules.0.temporal_transformer.proj_out.bias: torch.Size([320]) +unet.down_blocks.0.motion_modules.1.temporal_transformer.norm.weight: torch.Size([320]) +unet.down_blocks.0.motion_modules.1.temporal_transformer.norm.bias: torch.Size([320]) +unet.down_blocks.0.motion_modules.1.temporal_transformer.proj_in.weight: torch.Size([320, 320]) +unet.down_blocks.0.motion_modules.1.temporal_transformer.proj_in.bias: torch.Size([320]) +unet.down_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_q.weight: torch.Size([320, 320]) +unet.down_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_k.weight: torch.Size([320, 320]) +unet.down_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_v.weight: torch.Size([320, 320]) +unet.down_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.weight: torch.Size([320, 320]) +unet.down_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.bias: torch.Size([320]) +unet.down_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe: torch.Size([1, 32, 320]) +unet.down_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_q.weight: torch.Size([320, 320]) +unet.down_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_k.weight: torch.Size([320, 320]) +unet.down_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_v.weight: torch.Size([320, 320]) +unet.down_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.weight: torch.Size([320, 320]) +unet.down_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.bias: torch.Size([320]) +unet.down_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe: torch.Size([1, 32, 320]) +unet.down_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.norms.0.weight: torch.Size([320]) +unet.down_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.norms.0.bias: torch.Size([320]) +unet.down_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.norms.1.weight: torch.Size([320]) +unet.down_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.norms.1.bias: torch.Size([320]) +unet.down_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.ff.net.0.proj.weight: torch.Size([2560, 320]) +unet.down_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.ff.net.0.proj.bias: torch.Size([2560]) +unet.down_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.ff.net.2.weight: torch.Size([320, 1280]) +unet.down_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.ff.net.2.bias: torch.Size([320]) +unet.down_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.ff_norm.weight: torch.Size([320]) +unet.down_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.ff_norm.bias: torch.Size([320]) +unet.down_blocks.0.motion_modules.1.temporal_transformer.proj_out.weight: torch.Size([320, 320]) +unet.down_blocks.0.motion_modules.1.temporal_transformer.proj_out.bias: torch.Size([320]) +unet.down_blocks.1.motion_modules.0.temporal_transformer.norm.weight: torch.Size([640]) +unet.down_blocks.1.motion_modules.0.temporal_transformer.norm.bias: torch.Size([640]) +unet.down_blocks.1.motion_modules.0.temporal_transformer.proj_in.weight: torch.Size([640, 640]) +unet.down_blocks.1.motion_modules.0.temporal_transformer.proj_in.bias: torch.Size([640]) +unet.down_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_q.weight: torch.Size([640, 640]) +unet.down_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_k.weight: torch.Size([640, 640]) +unet.down_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_v.weight: torch.Size([640, 640]) +unet.down_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.weight: torch.Size([640, 640]) +unet.down_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.bias: torch.Size([640]) +unet.down_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe: torch.Size([1, 32, 640]) +unet.down_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_q.weight: torch.Size([640, 640]) +unet.down_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_k.weight: torch.Size([640, 640]) +unet.down_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_v.weight: torch.Size([640, 640]) +unet.down_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.weight: torch.Size([640, 640]) +unet.down_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.bias: torch.Size([640]) +unet.down_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe: torch.Size([1, 32, 640]) +unet.down_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.norms.0.weight: torch.Size([640]) +unet.down_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.norms.0.bias: torch.Size([640]) +unet.down_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.norms.1.weight: torch.Size([640]) +unet.down_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.norms.1.bias: torch.Size([640]) +unet.down_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.ff.net.0.proj.weight: torch.Size([5120, 640]) +unet.down_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.ff.net.0.proj.bias: torch.Size([5120]) +unet.down_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.ff.net.2.weight: torch.Size([640, 2560]) +unet.down_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.ff.net.2.bias: torch.Size([640]) +unet.down_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.ff_norm.weight: torch.Size([640]) +unet.down_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.ff_norm.bias: torch.Size([640]) +unet.down_blocks.1.motion_modules.0.temporal_transformer.proj_out.weight: torch.Size([640, 640]) +unet.down_blocks.1.motion_modules.0.temporal_transformer.proj_out.bias: torch.Size([640]) +unet.down_blocks.1.motion_modules.1.temporal_transformer.norm.weight: torch.Size([640]) +unet.down_blocks.1.motion_modules.1.temporal_transformer.norm.bias: torch.Size([640]) +unet.down_blocks.1.motion_modules.1.temporal_transformer.proj_in.weight: torch.Size([640, 640]) +unet.down_blocks.1.motion_modules.1.temporal_transformer.proj_in.bias: torch.Size([640]) +unet.down_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_q.weight: torch.Size([640, 640]) +unet.down_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_k.weight: torch.Size([640, 640]) +unet.down_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_v.weight: torch.Size([640, 640]) +unet.down_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.weight: torch.Size([640, 640]) +unet.down_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.bias: torch.Size([640]) +unet.down_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe: torch.Size([1, 32, 640]) +unet.down_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_q.weight: torch.Size([640, 640]) +unet.down_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_k.weight: torch.Size([640, 640]) +unet.down_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_v.weight: torch.Size([640, 640]) +unet.down_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.weight: torch.Size([640, 640]) +unet.down_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.bias: torch.Size([640]) +unet.down_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe: torch.Size([1, 32, 640]) +unet.down_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.norms.0.weight: torch.Size([640]) +unet.down_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.norms.0.bias: torch.Size([640]) +unet.down_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.norms.1.weight: torch.Size([640]) +unet.down_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.norms.1.bias: torch.Size([640]) +unet.down_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.ff.net.0.proj.weight: torch.Size([5120, 640]) +unet.down_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.ff.net.0.proj.bias: torch.Size([5120]) +unet.down_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.ff.net.2.weight: torch.Size([640, 2560]) +unet.down_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.ff.net.2.bias: torch.Size([640]) +unet.down_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.ff_norm.weight: torch.Size([640]) +unet.down_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.ff_norm.bias: torch.Size([640]) +unet.down_blocks.1.motion_modules.1.temporal_transformer.proj_out.weight: torch.Size([640, 640]) +unet.down_blocks.1.motion_modules.1.temporal_transformer.proj_out.bias: torch.Size([640]) +unet.down_blocks.2.motion_modules.0.temporal_transformer.norm.weight: torch.Size([1280]) +unet.down_blocks.2.motion_modules.0.temporal_transformer.norm.bias: torch.Size([1280]) +unet.down_blocks.2.motion_modules.0.temporal_transformer.proj_in.weight: torch.Size([1280, 1280]) +unet.down_blocks.2.motion_modules.0.temporal_transformer.proj_in.bias: torch.Size([1280]) +unet.down_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_q.weight: torch.Size([1280, 1280]) +unet.down_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_k.weight: torch.Size([1280, 1280]) +unet.down_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_v.weight: torch.Size([1280, 1280]) +unet.down_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.weight: torch.Size([1280, 1280]) +unet.down_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.bias: torch.Size([1280]) +unet.down_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe: torch.Size([1, 32, 1280]) +unet.down_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_q.weight: torch.Size([1280, 1280]) +unet.down_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_k.weight: torch.Size([1280, 1280]) +unet.down_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_v.weight: torch.Size([1280, 1280]) +unet.down_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.weight: torch.Size([1280, 1280]) +unet.down_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.bias: torch.Size([1280]) +unet.down_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe: torch.Size([1, 32, 1280]) +unet.down_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.norms.0.weight: torch.Size([1280]) +unet.down_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.norms.0.bias: torch.Size([1280]) +unet.down_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.norms.1.weight: torch.Size([1280]) +unet.down_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.norms.1.bias: torch.Size([1280]) +unet.down_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.ff.net.0.proj.weight: torch.Size([10240, 1280]) +unet.down_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.ff.net.0.proj.bias: torch.Size([10240]) +unet.down_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.ff.net.2.weight: torch.Size([1280, 5120]) +unet.down_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.ff.net.2.bias: torch.Size([1280]) +unet.down_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.ff_norm.weight: torch.Size([1280]) +unet.down_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.ff_norm.bias: torch.Size([1280]) +unet.down_blocks.2.motion_modules.0.temporal_transformer.proj_out.weight: torch.Size([1280, 1280]) +unet.down_blocks.2.motion_modules.0.temporal_transformer.proj_out.bias: torch.Size([1280]) +unet.down_blocks.2.motion_modules.1.temporal_transformer.norm.weight: torch.Size([1280]) +unet.down_blocks.2.motion_modules.1.temporal_transformer.norm.bias: torch.Size([1280]) +unet.down_blocks.2.motion_modules.1.temporal_transformer.proj_in.weight: torch.Size([1280, 1280]) +unet.down_blocks.2.motion_modules.1.temporal_transformer.proj_in.bias: torch.Size([1280]) +unet.down_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_q.weight: torch.Size([1280, 1280]) +unet.down_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_k.weight: torch.Size([1280, 1280]) +unet.down_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_v.weight: torch.Size([1280, 1280]) +unet.down_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.weight: torch.Size([1280, 1280]) +unet.down_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.bias: torch.Size([1280]) +unet.down_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe: torch.Size([1, 32, 1280]) +unet.down_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_q.weight: torch.Size([1280, 1280]) +unet.down_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_k.weight: torch.Size([1280, 1280]) +unet.down_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_v.weight: torch.Size([1280, 1280]) +unet.down_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.weight: torch.Size([1280, 1280]) +unet.down_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.bias: torch.Size([1280]) +unet.down_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe: torch.Size([1, 32, 1280]) +unet.down_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.norms.0.weight: torch.Size([1280]) +unet.down_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.norms.0.bias: torch.Size([1280]) +unet.down_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.norms.1.weight: torch.Size([1280]) +unet.down_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.norms.1.bias: torch.Size([1280]) +unet.down_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.ff.net.0.proj.weight: torch.Size([10240, 1280]) +unet.down_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.ff.net.0.proj.bias: torch.Size([10240]) +unet.down_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.ff.net.2.weight: torch.Size([1280, 5120]) +unet.down_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.ff.net.2.bias: torch.Size([1280]) +unet.down_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.ff_norm.weight: torch.Size([1280]) +unet.down_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.ff_norm.bias: torch.Size([1280]) +unet.down_blocks.2.motion_modules.1.temporal_transformer.proj_out.weight: torch.Size([1280, 1280]) +unet.down_blocks.2.motion_modules.1.temporal_transformer.proj_out.bias: torch.Size([1280]) +unet.down_blocks.3.motion_modules.0.temporal_transformer.norm.weight: torch.Size([1280]) +unet.down_blocks.3.motion_modules.0.temporal_transformer.norm.bias: torch.Size([1280]) +unet.down_blocks.3.motion_modules.0.temporal_transformer.proj_in.weight: torch.Size([1280, 1280]) +unet.down_blocks.3.motion_modules.0.temporal_transformer.proj_in.bias: torch.Size([1280]) +unet.down_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_q.weight: torch.Size([1280, 1280]) +unet.down_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_k.weight: torch.Size([1280, 1280]) +unet.down_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_v.weight: torch.Size([1280, 1280]) +unet.down_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.weight: torch.Size([1280, 1280]) +unet.down_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.bias: torch.Size([1280]) +unet.down_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe: torch.Size([1, 32, 1280]) +unet.down_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_q.weight: torch.Size([1280, 1280]) +unet.down_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_k.weight: torch.Size([1280, 1280]) +unet.down_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_v.weight: torch.Size([1280, 1280]) +unet.down_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.weight: torch.Size([1280, 1280]) +unet.down_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.bias: torch.Size([1280]) +unet.down_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe: torch.Size([1, 32, 1280]) +unet.down_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.norms.0.weight: torch.Size([1280]) +unet.down_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.norms.0.bias: torch.Size([1280]) +unet.down_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.norms.1.weight: torch.Size([1280]) +unet.down_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.norms.1.bias: torch.Size([1280]) +unet.down_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.ff.net.0.proj.weight: torch.Size([10240, 1280]) +unet.down_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.ff.net.0.proj.bias: torch.Size([10240]) +unet.down_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.ff.net.2.weight: torch.Size([1280, 5120]) +unet.down_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.ff.net.2.bias: torch.Size([1280]) +unet.down_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.ff_norm.weight: torch.Size([1280]) +unet.down_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.ff_norm.bias: torch.Size([1280]) +unet.down_blocks.3.motion_modules.0.temporal_transformer.proj_out.weight: torch.Size([1280, 1280]) +unet.down_blocks.3.motion_modules.0.temporal_transformer.proj_out.bias: torch.Size([1280]) +unet.down_blocks.3.motion_modules.1.temporal_transformer.norm.weight: torch.Size([1280]) +unet.down_blocks.3.motion_modules.1.temporal_transformer.norm.bias: torch.Size([1280]) +unet.down_blocks.3.motion_modules.1.temporal_transformer.proj_in.weight: torch.Size([1280, 1280]) +unet.down_blocks.3.motion_modules.1.temporal_transformer.proj_in.bias: torch.Size([1280]) +unet.down_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_q.weight: torch.Size([1280, 1280]) +unet.down_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_k.weight: torch.Size([1280, 1280]) +unet.down_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_v.weight: torch.Size([1280, 1280]) +unet.down_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.weight: torch.Size([1280, 1280]) +unet.down_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.bias: torch.Size([1280]) +unet.down_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe: torch.Size([1, 32, 1280]) +unet.down_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_q.weight: torch.Size([1280, 1280]) +unet.down_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_k.weight: torch.Size([1280, 1280]) +unet.down_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_v.weight: torch.Size([1280, 1280]) +unet.down_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.weight: torch.Size([1280, 1280]) +unet.down_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.bias: torch.Size([1280]) +unet.down_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe: torch.Size([1, 32, 1280]) +unet.down_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.norms.0.weight: torch.Size([1280]) +unet.down_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.norms.0.bias: torch.Size([1280]) +unet.down_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.norms.1.weight: torch.Size([1280]) +unet.down_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.norms.1.bias: torch.Size([1280]) +unet.down_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.ff.net.0.proj.weight: torch.Size([10240, 1280]) +unet.down_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.ff.net.0.proj.bias: torch.Size([10240]) +unet.down_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.ff.net.2.weight: torch.Size([1280, 5120]) +unet.down_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.ff.net.2.bias: torch.Size([1280]) +unet.down_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.ff_norm.weight: torch.Size([1280]) +unet.down_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.ff_norm.bias: torch.Size([1280]) +unet.down_blocks.3.motion_modules.1.temporal_transformer.proj_out.weight: torch.Size([1280, 1280]) +unet.down_blocks.3.motion_modules.1.temporal_transformer.proj_out.bias: torch.Size([1280]) +unet.up_blocks.0.motion_modules.0.temporal_transformer.norm.weight: torch.Size([1280]) +unet.up_blocks.0.motion_modules.0.temporal_transformer.norm.bias: torch.Size([1280]) +unet.up_blocks.0.motion_modules.0.temporal_transformer.proj_in.weight: torch.Size([1280, 1280]) +unet.up_blocks.0.motion_modules.0.temporal_transformer.proj_in.bias: torch.Size([1280]) +unet.up_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_q.weight: torch.Size([1280, 1280]) +unet.up_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_k.weight: torch.Size([1280, 1280]) +unet.up_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_v.weight: torch.Size([1280, 1280]) +unet.up_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.weight: torch.Size([1280, 1280]) +unet.up_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.bias: torch.Size([1280]) +unet.up_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe: torch.Size([1, 32, 1280]) +unet.up_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_q.weight: torch.Size([1280, 1280]) +unet.up_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_k.weight: torch.Size([1280, 1280]) +unet.up_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_v.weight: torch.Size([1280, 1280]) +unet.up_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.weight: torch.Size([1280, 1280]) +unet.up_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.bias: torch.Size([1280]) +unet.up_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe: torch.Size([1, 32, 1280]) +unet.up_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.norms.0.weight: torch.Size([1280]) +unet.up_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.norms.0.bias: torch.Size([1280]) +unet.up_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.norms.1.weight: torch.Size([1280]) +unet.up_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.norms.1.bias: torch.Size([1280]) +unet.up_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.ff.net.0.proj.weight: torch.Size([10240, 1280]) +unet.up_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.ff.net.0.proj.bias: torch.Size([10240]) +unet.up_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.ff.net.2.weight: torch.Size([1280, 5120]) +unet.up_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.ff.net.2.bias: torch.Size([1280]) +unet.up_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.ff_norm.weight: torch.Size([1280]) +unet.up_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.ff_norm.bias: torch.Size([1280]) +unet.up_blocks.0.motion_modules.0.temporal_transformer.proj_out.weight: torch.Size([1280, 1280]) +unet.up_blocks.0.motion_modules.0.temporal_transformer.proj_out.bias: torch.Size([1280]) +unet.up_blocks.0.motion_modules.1.temporal_transformer.norm.weight: torch.Size([1280]) +unet.up_blocks.0.motion_modules.1.temporal_transformer.norm.bias: torch.Size([1280]) +unet.up_blocks.0.motion_modules.1.temporal_transformer.proj_in.weight: torch.Size([1280, 1280]) +unet.up_blocks.0.motion_modules.1.temporal_transformer.proj_in.bias: torch.Size([1280]) +unet.up_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_q.weight: torch.Size([1280, 1280]) +unet.up_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_k.weight: torch.Size([1280, 1280]) +unet.up_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_v.weight: torch.Size([1280, 1280]) +unet.up_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.weight: torch.Size([1280, 1280]) +unet.up_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.bias: torch.Size([1280]) +unet.up_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe: torch.Size([1, 32, 1280]) +unet.up_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_q.weight: torch.Size([1280, 1280]) +unet.up_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_k.weight: torch.Size([1280, 1280]) +unet.up_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_v.weight: torch.Size([1280, 1280]) +unet.up_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.weight: torch.Size([1280, 1280]) +unet.up_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.bias: torch.Size([1280]) +unet.up_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe: torch.Size([1, 32, 1280]) +unet.up_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.norms.0.weight: torch.Size([1280]) +unet.up_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.norms.0.bias: torch.Size([1280]) +unet.up_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.norms.1.weight: torch.Size([1280]) +unet.up_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.norms.1.bias: torch.Size([1280]) +unet.up_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.ff.net.0.proj.weight: torch.Size([10240, 1280]) +unet.up_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.ff.net.0.proj.bias: torch.Size([10240]) +unet.up_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.ff.net.2.weight: torch.Size([1280, 5120]) +unet.up_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.ff.net.2.bias: torch.Size([1280]) +unet.up_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.ff_norm.weight: torch.Size([1280]) +unet.up_blocks.0.motion_modules.1.temporal_transformer.transformer_blocks.0.ff_norm.bias: torch.Size([1280]) +unet.up_blocks.0.motion_modules.1.temporal_transformer.proj_out.weight: torch.Size([1280, 1280]) +unet.up_blocks.0.motion_modules.1.temporal_transformer.proj_out.bias: torch.Size([1280]) +unet.up_blocks.0.motion_modules.2.temporal_transformer.norm.weight: torch.Size([1280]) +unet.up_blocks.0.motion_modules.2.temporal_transformer.norm.bias: torch.Size([1280]) +unet.up_blocks.0.motion_modules.2.temporal_transformer.proj_in.weight: torch.Size([1280, 1280]) +unet.up_blocks.0.motion_modules.2.temporal_transformer.proj_in.bias: torch.Size([1280]) +unet.up_blocks.0.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_q.weight: torch.Size([1280, 1280]) +unet.up_blocks.0.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_k.weight: torch.Size([1280, 1280]) +unet.up_blocks.0.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_v.weight: torch.Size([1280, 1280]) +unet.up_blocks.0.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.weight: torch.Size([1280, 1280]) +unet.up_blocks.0.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.bias: torch.Size([1280]) +unet.up_blocks.0.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe: torch.Size([1, 32, 1280]) +unet.up_blocks.0.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_q.weight: torch.Size([1280, 1280]) +unet.up_blocks.0.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_k.weight: torch.Size([1280, 1280]) +unet.up_blocks.0.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_v.weight: torch.Size([1280, 1280]) +unet.up_blocks.0.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.weight: torch.Size([1280, 1280]) +unet.up_blocks.0.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.bias: torch.Size([1280]) +unet.up_blocks.0.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe: torch.Size([1, 32, 1280]) +unet.up_blocks.0.motion_modules.2.temporal_transformer.transformer_blocks.0.norms.0.weight: torch.Size([1280]) +unet.up_blocks.0.motion_modules.2.temporal_transformer.transformer_blocks.0.norms.0.bias: torch.Size([1280]) +unet.up_blocks.0.motion_modules.2.temporal_transformer.transformer_blocks.0.norms.1.weight: torch.Size([1280]) +unet.up_blocks.0.motion_modules.2.temporal_transformer.transformer_blocks.0.norms.1.bias: torch.Size([1280]) +unet.up_blocks.0.motion_modules.2.temporal_transformer.transformer_blocks.0.ff.net.0.proj.weight: torch.Size([10240, 1280]) +unet.up_blocks.0.motion_modules.2.temporal_transformer.transformer_blocks.0.ff.net.0.proj.bias: torch.Size([10240]) +unet.up_blocks.0.motion_modules.2.temporal_transformer.transformer_blocks.0.ff.net.2.weight: torch.Size([1280, 5120]) +unet.up_blocks.0.motion_modules.2.temporal_transformer.transformer_blocks.0.ff.net.2.bias: torch.Size([1280]) +unet.up_blocks.0.motion_modules.2.temporal_transformer.transformer_blocks.0.ff_norm.weight: torch.Size([1280]) +unet.up_blocks.0.motion_modules.2.temporal_transformer.transformer_blocks.0.ff_norm.bias: torch.Size([1280]) +unet.up_blocks.0.motion_modules.2.temporal_transformer.proj_out.weight: torch.Size([1280, 1280]) +unet.up_blocks.0.motion_modules.2.temporal_transformer.proj_out.bias: torch.Size([1280]) +unet.up_blocks.1.motion_modules.0.temporal_transformer.norm.weight: torch.Size([1280]) +unet.up_blocks.1.motion_modules.0.temporal_transformer.norm.bias: torch.Size([1280]) +unet.up_blocks.1.motion_modules.0.temporal_transformer.proj_in.weight: torch.Size([1280, 1280]) +unet.up_blocks.1.motion_modules.0.temporal_transformer.proj_in.bias: torch.Size([1280]) +unet.up_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_q.weight: torch.Size([1280, 1280]) +unet.up_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_k.weight: torch.Size([1280, 1280]) +unet.up_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_v.weight: torch.Size([1280, 1280]) +unet.up_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.weight: torch.Size([1280, 1280]) +unet.up_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.bias: torch.Size([1280]) +unet.up_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe: torch.Size([1, 32, 1280]) +unet.up_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_q.weight: torch.Size([1280, 1280]) +unet.up_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_k.weight: torch.Size([1280, 1280]) +unet.up_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_v.weight: torch.Size([1280, 1280]) +unet.up_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.weight: torch.Size([1280, 1280]) +unet.up_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.bias: torch.Size([1280]) +unet.up_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe: torch.Size([1, 32, 1280]) +unet.up_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.norms.0.weight: torch.Size([1280]) +unet.up_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.norms.0.bias: torch.Size([1280]) +unet.up_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.norms.1.weight: torch.Size([1280]) +unet.up_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.norms.1.bias: torch.Size([1280]) +unet.up_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.ff.net.0.proj.weight: torch.Size([10240, 1280]) +unet.up_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.ff.net.0.proj.bias: torch.Size([10240]) +unet.up_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.ff.net.2.weight: torch.Size([1280, 5120]) +unet.up_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.ff.net.2.bias: torch.Size([1280]) +unet.up_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.ff_norm.weight: torch.Size([1280]) +unet.up_blocks.1.motion_modules.0.temporal_transformer.transformer_blocks.0.ff_norm.bias: torch.Size([1280]) +unet.up_blocks.1.motion_modules.0.temporal_transformer.proj_out.weight: torch.Size([1280, 1280]) +unet.up_blocks.1.motion_modules.0.temporal_transformer.proj_out.bias: torch.Size([1280]) +unet.up_blocks.1.motion_modules.1.temporal_transformer.norm.weight: torch.Size([1280]) +unet.up_blocks.1.motion_modules.1.temporal_transformer.norm.bias: torch.Size([1280]) +unet.up_blocks.1.motion_modules.1.temporal_transformer.proj_in.weight: torch.Size([1280, 1280]) +unet.up_blocks.1.motion_modules.1.temporal_transformer.proj_in.bias: torch.Size([1280]) +unet.up_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_q.weight: torch.Size([1280, 1280]) +unet.up_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_k.weight: torch.Size([1280, 1280]) +unet.up_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_v.weight: torch.Size([1280, 1280]) +unet.up_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.weight: torch.Size([1280, 1280]) +unet.up_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.bias: torch.Size([1280]) +unet.up_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe: torch.Size([1, 32, 1280]) +unet.up_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_q.weight: torch.Size([1280, 1280]) +unet.up_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_k.weight: torch.Size([1280, 1280]) +unet.up_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_v.weight: torch.Size([1280, 1280]) +unet.up_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.weight: torch.Size([1280, 1280]) +unet.up_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.bias: torch.Size([1280]) +unet.up_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe: torch.Size([1, 32, 1280]) +unet.up_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.norms.0.weight: torch.Size([1280]) +unet.up_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.norms.0.bias: torch.Size([1280]) +unet.up_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.norms.1.weight: torch.Size([1280]) +unet.up_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.norms.1.bias: torch.Size([1280]) +unet.up_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.ff.net.0.proj.weight: torch.Size([10240, 1280]) +unet.up_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.ff.net.0.proj.bias: torch.Size([10240]) +unet.up_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.ff.net.2.weight: torch.Size([1280, 5120]) +unet.up_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.ff.net.2.bias: torch.Size([1280]) +unet.up_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.ff_norm.weight: torch.Size([1280]) +unet.up_blocks.1.motion_modules.1.temporal_transformer.transformer_blocks.0.ff_norm.bias: torch.Size([1280]) +unet.up_blocks.1.motion_modules.1.temporal_transformer.proj_out.weight: torch.Size([1280, 1280]) +unet.up_blocks.1.motion_modules.1.temporal_transformer.proj_out.bias: torch.Size([1280]) +unet.up_blocks.1.motion_modules.2.temporal_transformer.norm.weight: torch.Size([1280]) +unet.up_blocks.1.motion_modules.2.temporal_transformer.norm.bias: torch.Size([1280]) +unet.up_blocks.1.motion_modules.2.temporal_transformer.proj_in.weight: torch.Size([1280, 1280]) +unet.up_blocks.1.motion_modules.2.temporal_transformer.proj_in.bias: torch.Size([1280]) +unet.up_blocks.1.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_q.weight: torch.Size([1280, 1280]) +unet.up_blocks.1.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_k.weight: torch.Size([1280, 1280]) +unet.up_blocks.1.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_v.weight: torch.Size([1280, 1280]) +unet.up_blocks.1.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.weight: torch.Size([1280, 1280]) +unet.up_blocks.1.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.bias: torch.Size([1280]) +unet.up_blocks.1.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe: torch.Size([1, 32, 1280]) +unet.up_blocks.1.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_q.weight: torch.Size([1280, 1280]) +unet.up_blocks.1.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_k.weight: torch.Size([1280, 1280]) +unet.up_blocks.1.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_v.weight: torch.Size([1280, 1280]) +unet.up_blocks.1.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.weight: torch.Size([1280, 1280]) +unet.up_blocks.1.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.bias: torch.Size([1280]) +unet.up_blocks.1.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe: torch.Size([1, 32, 1280]) +unet.up_blocks.1.motion_modules.2.temporal_transformer.transformer_blocks.0.norms.0.weight: torch.Size([1280]) +unet.up_blocks.1.motion_modules.2.temporal_transformer.transformer_blocks.0.norms.0.bias: torch.Size([1280]) +unet.up_blocks.1.motion_modules.2.temporal_transformer.transformer_blocks.0.norms.1.weight: torch.Size([1280]) +unet.up_blocks.1.motion_modules.2.temporal_transformer.transformer_blocks.0.norms.1.bias: torch.Size([1280]) +unet.up_blocks.1.motion_modules.2.temporal_transformer.transformer_blocks.0.ff.net.0.proj.weight: torch.Size([10240, 1280]) +unet.up_blocks.1.motion_modules.2.temporal_transformer.transformer_blocks.0.ff.net.0.proj.bias: torch.Size([10240]) +unet.up_blocks.1.motion_modules.2.temporal_transformer.transformer_blocks.0.ff.net.2.weight: torch.Size([1280, 5120]) +unet.up_blocks.1.motion_modules.2.temporal_transformer.transformer_blocks.0.ff.net.2.bias: torch.Size([1280]) +unet.up_blocks.1.motion_modules.2.temporal_transformer.transformer_blocks.0.ff_norm.weight: torch.Size([1280]) +unet.up_blocks.1.motion_modules.2.temporal_transformer.transformer_blocks.0.ff_norm.bias: torch.Size([1280]) +unet.up_blocks.1.motion_modules.2.temporal_transformer.proj_out.weight: torch.Size([1280, 1280]) +unet.up_blocks.1.motion_modules.2.temporal_transformer.proj_out.bias: torch.Size([1280]) +unet.up_blocks.2.motion_modules.0.temporal_transformer.norm.weight: torch.Size([640]) +unet.up_blocks.2.motion_modules.0.temporal_transformer.norm.bias: torch.Size([640]) +unet.up_blocks.2.motion_modules.0.temporal_transformer.proj_in.weight: torch.Size([640, 640]) +unet.up_blocks.2.motion_modules.0.temporal_transformer.proj_in.bias: torch.Size([640]) +unet.up_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_q.weight: torch.Size([640, 640]) +unet.up_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_k.weight: torch.Size([640, 640]) +unet.up_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_v.weight: torch.Size([640, 640]) +unet.up_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.weight: torch.Size([640, 640]) +unet.up_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.bias: torch.Size([640]) +unet.up_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe: torch.Size([1, 32, 640]) +unet.up_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_q.weight: torch.Size([640, 640]) +unet.up_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_k.weight: torch.Size([640, 640]) +unet.up_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_v.weight: torch.Size([640, 640]) +unet.up_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.weight: torch.Size([640, 640]) +unet.up_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.bias: torch.Size([640]) +unet.up_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe: torch.Size([1, 32, 640]) +unet.up_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.norms.0.weight: torch.Size([640]) +unet.up_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.norms.0.bias: torch.Size([640]) +unet.up_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.norms.1.weight: torch.Size([640]) +unet.up_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.norms.1.bias: torch.Size([640]) +unet.up_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.ff.net.0.proj.weight: torch.Size([5120, 640]) +unet.up_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.ff.net.0.proj.bias: torch.Size([5120]) +unet.up_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.ff.net.2.weight: torch.Size([640, 2560]) +unet.up_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.ff.net.2.bias: torch.Size([640]) +unet.up_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.ff_norm.weight: torch.Size([640]) +unet.up_blocks.2.motion_modules.0.temporal_transformer.transformer_blocks.0.ff_norm.bias: torch.Size([640]) +unet.up_blocks.2.motion_modules.0.temporal_transformer.proj_out.weight: torch.Size([640, 640]) +unet.up_blocks.2.motion_modules.0.temporal_transformer.proj_out.bias: torch.Size([640]) +unet.up_blocks.2.motion_modules.1.temporal_transformer.norm.weight: torch.Size([640]) +unet.up_blocks.2.motion_modules.1.temporal_transformer.norm.bias: torch.Size([640]) +unet.up_blocks.2.motion_modules.1.temporal_transformer.proj_in.weight: torch.Size([640, 640]) +unet.up_blocks.2.motion_modules.1.temporal_transformer.proj_in.bias: torch.Size([640]) +unet.up_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_q.weight: torch.Size([640, 640]) +unet.up_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_k.weight: torch.Size([640, 640]) +unet.up_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_v.weight: torch.Size([640, 640]) +unet.up_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.weight: torch.Size([640, 640]) +unet.up_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.bias: torch.Size([640]) +unet.up_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe: torch.Size([1, 32, 640]) +unet.up_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_q.weight: torch.Size([640, 640]) +unet.up_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_k.weight: torch.Size([640, 640]) +unet.up_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_v.weight: torch.Size([640, 640]) +unet.up_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.weight: torch.Size([640, 640]) +unet.up_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.bias: torch.Size([640]) +unet.up_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe: torch.Size([1, 32, 640]) +unet.up_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.norms.0.weight: torch.Size([640]) +unet.up_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.norms.0.bias: torch.Size([640]) +unet.up_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.norms.1.weight: torch.Size([640]) +unet.up_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.norms.1.bias: torch.Size([640]) +unet.up_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.ff.net.0.proj.weight: torch.Size([5120, 640]) +unet.up_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.ff.net.0.proj.bias: torch.Size([5120]) +unet.up_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.ff.net.2.weight: torch.Size([640, 2560]) +unet.up_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.ff.net.2.bias: torch.Size([640]) +unet.up_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.ff_norm.weight: torch.Size([640]) +unet.up_blocks.2.motion_modules.1.temporal_transformer.transformer_blocks.0.ff_norm.bias: torch.Size([640]) +unet.up_blocks.2.motion_modules.1.temporal_transformer.proj_out.weight: torch.Size([640, 640]) +unet.up_blocks.2.motion_modules.1.temporal_transformer.proj_out.bias: torch.Size([640]) +unet.up_blocks.2.motion_modules.2.temporal_transformer.norm.weight: torch.Size([640]) +unet.up_blocks.2.motion_modules.2.temporal_transformer.norm.bias: torch.Size([640]) +unet.up_blocks.2.motion_modules.2.temporal_transformer.proj_in.weight: torch.Size([640, 640]) +unet.up_blocks.2.motion_modules.2.temporal_transformer.proj_in.bias: torch.Size([640]) +unet.up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_q.weight: torch.Size([640, 640]) +unet.up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_k.weight: torch.Size([640, 640]) +unet.up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_v.weight: torch.Size([640, 640]) +unet.up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.weight: torch.Size([640, 640]) +unet.up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.bias: torch.Size([640]) +unet.up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe: torch.Size([1, 32, 640]) +unet.up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_q.weight: torch.Size([640, 640]) +unet.up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_k.weight: torch.Size([640, 640]) +unet.up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_v.weight: torch.Size([640, 640]) +unet.up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.weight: torch.Size([640, 640]) +unet.up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.bias: torch.Size([640]) +unet.up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe: torch.Size([1, 32, 640]) +unet.up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.norms.0.weight: torch.Size([640]) +unet.up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.norms.0.bias: torch.Size([640]) +unet.up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.norms.1.weight: torch.Size([640]) +unet.up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.norms.1.bias: torch.Size([640]) +unet.up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.ff.net.0.proj.weight: torch.Size([5120, 640]) +unet.up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.ff.net.0.proj.bias: torch.Size([5120]) +unet.up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.ff.net.2.weight: torch.Size([640, 2560]) +unet.up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.ff.net.2.bias: torch.Size([640]) +unet.up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.ff_norm.weight: torch.Size([640]) +unet.up_blocks.2.motion_modules.2.temporal_transformer.transformer_blocks.0.ff_norm.bias: torch.Size([640]) +unet.up_blocks.2.motion_modules.2.temporal_transformer.proj_out.weight: torch.Size([640, 640]) +unet.up_blocks.2.motion_modules.2.temporal_transformer.proj_out.bias: torch.Size([640]) +unet.up_blocks.3.motion_modules.0.temporal_transformer.norm.weight: torch.Size([320]) +unet.up_blocks.3.motion_modules.0.temporal_transformer.norm.bias: torch.Size([320]) +unet.up_blocks.3.motion_modules.0.temporal_transformer.proj_in.weight: torch.Size([320, 320]) +unet.up_blocks.3.motion_modules.0.temporal_transformer.proj_in.bias: torch.Size([320]) +unet.up_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_q.weight: torch.Size([320, 320]) +unet.up_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_k.weight: torch.Size([320, 320]) +unet.up_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_v.weight: torch.Size([320, 320]) +unet.up_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.weight: torch.Size([320, 320]) +unet.up_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.bias: torch.Size([320]) +unet.up_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe: torch.Size([1, 32, 320]) +unet.up_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_q.weight: torch.Size([320, 320]) +unet.up_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_k.weight: torch.Size([320, 320]) +unet.up_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_v.weight: torch.Size([320, 320]) +unet.up_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.weight: torch.Size([320, 320]) +unet.up_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.bias: torch.Size([320]) +unet.up_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe: torch.Size([1, 32, 320]) +unet.up_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.norms.0.weight: torch.Size([320]) +unet.up_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.norms.0.bias: torch.Size([320]) +unet.up_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.norms.1.weight: torch.Size([320]) +unet.up_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.norms.1.bias: torch.Size([320]) +unet.up_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.ff.net.0.proj.weight: torch.Size([2560, 320]) +unet.up_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.ff.net.0.proj.bias: torch.Size([2560]) +unet.up_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.ff.net.2.weight: torch.Size([320, 1280]) +unet.up_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.ff.net.2.bias: torch.Size([320]) +unet.up_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.ff_norm.weight: torch.Size([320]) +unet.up_blocks.3.motion_modules.0.temporal_transformer.transformer_blocks.0.ff_norm.bias: torch.Size([320]) +unet.up_blocks.3.motion_modules.0.temporal_transformer.proj_out.weight: torch.Size([320, 320]) +unet.up_blocks.3.motion_modules.0.temporal_transformer.proj_out.bias: torch.Size([320]) +unet.up_blocks.3.motion_modules.1.temporal_transformer.norm.weight: torch.Size([320]) +unet.up_blocks.3.motion_modules.1.temporal_transformer.norm.bias: torch.Size([320]) +unet.up_blocks.3.motion_modules.1.temporal_transformer.proj_in.weight: torch.Size([320, 320]) +unet.up_blocks.3.motion_modules.1.temporal_transformer.proj_in.bias: torch.Size([320]) +unet.up_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_q.weight: torch.Size([320, 320]) +unet.up_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_k.weight: torch.Size([320, 320]) +unet.up_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_v.weight: torch.Size([320, 320]) +unet.up_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.weight: torch.Size([320, 320]) +unet.up_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.bias: torch.Size([320]) +unet.up_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe: torch.Size([1, 32, 320]) +unet.up_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_q.weight: torch.Size([320, 320]) +unet.up_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_k.weight: torch.Size([320, 320]) +unet.up_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_v.weight: torch.Size([320, 320]) +unet.up_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.weight: torch.Size([320, 320]) +unet.up_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.bias: torch.Size([320]) +unet.up_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe: torch.Size([1, 32, 320]) +unet.up_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.norms.0.weight: torch.Size([320]) +unet.up_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.norms.0.bias: torch.Size([320]) +unet.up_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.norms.1.weight: torch.Size([320]) +unet.up_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.norms.1.bias: torch.Size([320]) +unet.up_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.ff.net.0.proj.weight: torch.Size([2560, 320]) +unet.up_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.ff.net.0.proj.bias: torch.Size([2560]) +unet.up_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.ff.net.2.weight: torch.Size([320, 1280]) +unet.up_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.ff.net.2.bias: torch.Size([320]) +unet.up_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.ff_norm.weight: torch.Size([320]) +unet.up_blocks.3.motion_modules.1.temporal_transformer.transformer_blocks.0.ff_norm.bias: torch.Size([320]) +unet.up_blocks.3.motion_modules.1.temporal_transformer.proj_out.weight: torch.Size([320, 320]) +unet.up_blocks.3.motion_modules.1.temporal_transformer.proj_out.bias: torch.Size([320]) +unet.up_blocks.3.motion_modules.2.temporal_transformer.norm.weight: torch.Size([320]) +unet.up_blocks.3.motion_modules.2.temporal_transformer.norm.bias: torch.Size([320]) +unet.up_blocks.3.motion_modules.2.temporal_transformer.proj_in.weight: torch.Size([320, 320]) +unet.up_blocks.3.motion_modules.2.temporal_transformer.proj_in.bias: torch.Size([320]) +unet.up_blocks.3.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_q.weight: torch.Size([320, 320]) +unet.up_blocks.3.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_k.weight: torch.Size([320, 320]) +unet.up_blocks.3.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_v.weight: torch.Size([320, 320]) +unet.up_blocks.3.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.weight: torch.Size([320, 320]) +unet.up_blocks.3.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.0.to_out.0.bias: torch.Size([320]) +unet.up_blocks.3.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe: torch.Size([1, 32, 320]) +unet.up_blocks.3.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_q.weight: torch.Size([320, 320]) +unet.up_blocks.3.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_k.weight: torch.Size([320, 320]) +unet.up_blocks.3.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_v.weight: torch.Size([320, 320]) +unet.up_blocks.3.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.weight: torch.Size([320, 320]) +unet.up_blocks.3.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.1.to_out.0.bias: torch.Size([320]) +unet.up_blocks.3.motion_modules.2.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe: torch.Size([1, 32, 320]) +unet.up_blocks.3.motion_modules.2.temporal_transformer.transformer_blocks.0.norms.0.weight: torch.Size([320]) +unet.up_blocks.3.motion_modules.2.temporal_transformer.transformer_blocks.0.norms.0.bias: torch.Size([320]) +unet.up_blocks.3.motion_modules.2.temporal_transformer.transformer_blocks.0.norms.1.weight: torch.Size([320]) +unet.up_blocks.3.motion_modules.2.temporal_transformer.transformer_blocks.0.norms.1.bias: torch.Size([320]) +unet.up_blocks.3.motion_modules.2.temporal_transformer.transformer_blocks.0.ff.net.0.proj.weight: torch.Size([2560, 320]) +unet.up_blocks.3.motion_modules.2.temporal_transformer.transformer_blocks.0.ff.net.0.proj.bias: torch.Size([2560]) +unet.up_blocks.3.motion_modules.2.temporal_transformer.transformer_blocks.0.ff.net.2.weight: torch.Size([320, 1280]) +unet.up_blocks.3.motion_modules.2.temporal_transformer.transformer_blocks.0.ff.net.2.bias: torch.Size([320]) +unet.up_blocks.3.motion_modules.2.temporal_transformer.transformer_blocks.0.ff_norm.weight: torch.Size([320]) +unet.up_blocks.3.motion_modules.2.temporal_transformer.transformer_blocks.0.ff_norm.bias: torch.Size([320]) +unet.up_blocks.3.motion_modules.2.temporal_transformer.proj_out.weight: torch.Size([320, 320]) +unet.up_blocks.3.motion_modules.2.temporal_transformer.proj_out.bias: torch.Size([320]) diff --git a/misc_utils/__pycache__/flow_utils.cpython-310.pyc b/misc_utils/__pycache__/flow_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a37d5bc9def24eda110060da7425d838a554227 Binary files /dev/null and b/misc_utils/__pycache__/flow_utils.cpython-310.pyc differ diff --git a/misc_utils/__pycache__/image_utils.cpython-310.pyc b/misc_utils/__pycache__/image_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ce78d709b29fc0ab396af2c80a8353c896ab9b2 Binary files /dev/null and b/misc_utils/__pycache__/image_utils.cpython-310.pyc differ diff --git a/misc_utils/__pycache__/image_utils.cpython-38.pyc b/misc_utils/__pycache__/image_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b162a6047541e7524b6be99eb0fb449d279adb15 Binary files /dev/null and b/misc_utils/__pycache__/image_utils.cpython-38.pyc differ diff --git a/misc_utils/__pycache__/model_utils.cpython-310.pyc b/misc_utils/__pycache__/model_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f00793e4719ef55fc5e443ac95c1d38ea1065c8 Binary files /dev/null and b/misc_utils/__pycache__/model_utils.cpython-310.pyc differ diff --git a/misc_utils/__pycache__/model_utils.cpython-38.pyc b/misc_utils/__pycache__/model_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbe8ffc9acdca0754ff046231285bb105a28864c Binary files /dev/null and b/misc_utils/__pycache__/model_utils.cpython-38.pyc differ diff --git a/misc_utils/__pycache__/train_utils.cpython-310.pyc b/misc_utils/__pycache__/train_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce67b384c02fa1e5bd14ae47381ded4937ad1ff2 Binary files /dev/null and b/misc_utils/__pycache__/train_utils.cpython-310.pyc differ diff --git a/misc_utils/__pycache__/train_utils.cpython-38.pyc b/misc_utils/__pycache__/train_utils.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ceda99e6bf224fb0b739076887b609a0f1f1f195 Binary files /dev/null and b/misc_utils/__pycache__/train_utils.cpython-38.pyc differ diff --git a/misc_utils/__pycache__/train_utils.cpython-39.pyc b/misc_utils/__pycache__/train_utils.cpython-39.pyc new file mode 100644 index 0000000000000000000000000000000000000000..641e18e5137161c3bae7cc3a27448c2a267b5b29 Binary files /dev/null and b/misc_utils/__pycache__/train_utils.cpython-39.pyc differ diff --git a/misc_utils/clip_similarity.py b/misc_utils/clip_similarity.py new file mode 100644 index 0000000000000000000000000000000000000000..577b9726e155bf1322171ed13ba0d80d5aa80a7b --- /dev/null +++ b/misc_utils/clip_similarity.py @@ -0,0 +1,47 @@ +# from https://github.com/timothybrooks/instruct-pix2pix/blob/main/metrics/clip_similarity.py + +import clip +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + + +class ClipSimilarity(nn.Module): + def __init__(self, name: str = "ViT-L/14"): + super().__init__() + assert name in ("RN50", "RN101", "RN50x4", "RN50x16", "RN50x64", "ViT-B/32", "ViT-B/16", "ViT-L/14", "ViT-L/14@336px") # fmt: skip + self.size = {"RN50x4": 288, "RN50x16": 384, "RN50x64": 448, "ViT-L/14@336px": 336}.get(name, 224) + + self.model, _ = clip.load(name, device="cpu", download_root="./") + self.model.eval().requires_grad_(False) + + self.register_buffer("mean", torch.tensor((0.48145466, 0.4578275, 0.40821073))) + self.register_buffer("std", torch.tensor((0.26862954, 0.26130258, 0.27577711))) + + def encode_text(self, text: list[str]) -> torch.Tensor: + text = clip.tokenize(text, truncate=True).to(next(self.parameters()).device) + text_features = self.model.encode_text(text) + text_features = text_features / text_features.norm(dim=1, keepdim=True) + return text_features + + def encode_image(self, image: torch.Tensor) -> torch.Tensor: # Input images in range [0, 1]. + image = F.interpolate(image.float(), size=self.size, mode="bicubic", align_corners=False) + image = image - rearrange(self.mean, "c -> 1 c 1 1") + image = image / rearrange(self.std, "c -> 1 c 1 1") + image_features = self.model.encode_image(image) + image_features = image_features / image_features.norm(dim=1, keepdim=True) + return image_features + + def forward( + self, image_0: torch.Tensor, image_1: torch.Tensor, text_0: list[str], text_1: list[str] + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + image_features_0 = self.encode_image(image_0) + image_features_1 = self.encode_image(image_1) + text_features_0 = self.encode_text(text_0) + text_features_1 = self.encode_text(text_1) + sim_0 = F.cosine_similarity(image_features_0, text_features_0) + sim_1 = F.cosine_similarity(image_features_1, text_features_1) + sim_direction = F.cosine_similarity(image_features_1 - image_features_0, text_features_1 - text_features_0) + sim_image = F.cosine_similarity(image_features_0, image_features_1) + return sim_0, sim_1, sim_direction, sim_image \ No newline at end of file diff --git a/misc_utils/flow_utils.py b/misc_utils/flow_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..069a1287d2e1329ee40be0ba72f3dfe1d02b7f08 --- /dev/null +++ b/misc_utils/flow_utils.py @@ -0,0 +1,189 @@ +''' +Usage: + +from misc_utils.flow_utils import RAFTFlow, load_image_as_tensor, warp_image, MyRandomPerspective, generate_sample +image = load_image_as_tensor('hamburger_pic.jpeg', image_size) +flow_estimator = RAFTFlow() +res = generate_sample( + image, + flow_estimator, + distortion_scale=distortion_scale, +) +f1 = res['input'][None] +f2 = res['target'][None] +flow = res['flow'][None] +f1_warp = warp_image(f1, flow) +show_image(f1_warp[0]) +show_image(f2[0]) +''' +import torch +import torch.nn.functional as F +import torchvision.transforms.functional as TF +from torchvision.models.optical_flow import raft_large, Raft_Large_Weights +import numpy as np + +def warp_image(image, flow, mode='bilinear'): + """ Warp an image using optical flow. + Args: + image (torch.Tensor): Input image tensor with shape (N, C, H, W). + flow (torch.Tensor): Optical flow tensor with shape (N, 2, H, W). + Returns: + warped_image (torch.Tensor): Warped image tensor with shape (N, C, H, W). + """ + # check shape + if len(image.shape) == 3: + image = image.unsqueeze(0) + if len(flow.shape) == 3: + flow = flow.unsqueeze(0) + if image.device != flow.device: + flow = flow.to(image.device) + assert image.shape[0] == flow.shape[0], f'Batch size of image and flow must be the same. Got {image.shape[0]} and {flow.shape[0]}.' + assert image.shape[2:] == flow.shape[2:], f'Height and width of image and flow must be the same. Got {image.shape[2:]} and {flow.shape[2:]}.' + # Generate a grid of sampling points + grid = torch.tensor( + np.array(np.meshgrid(range(image.shape[3]), range(image.shape[2]), indexing='xy')), + dtype=torch.float32, device=image.device + )[None] + grid = grid.permute(0, 2, 3, 1).repeat(image.shape[0], 1, 1, 1) # (N, H, W, 2) + grid += flow.permute(0, 2, 3, 1) # add optical flow to grid + + # Normalize grid to [-1, 1] + grid[:, :, :, 0] = 2 * (grid[:, :, :, 0] / (image.shape[3] - 1) - 0.5) + grid[:, :, :, 1] = 2 * (grid[:, :, :, 1] / (image.shape[2] - 1) - 0.5) + + # Sample input image using the grid + warped_image = F.grid_sample(image, grid, mode=mode, align_corners=True) + + return warped_image + +def resize_flow(flow, size): + """ + Resize optical flow tensor to a new size. + + Args: + flow (torch.Tensor): Optical flow tensor with shape (B, 2, H, W). + size (tuple[int, int]): Target size as a tuple (H, W). + + Returns: + flow_resized (torch.Tensor): Resized optical flow tensor with shape (B, 2, H, W). + """ + # Unpack the target size + H, W = size + + # Compute the scaling factors + h, w = flow.shape[2:] + scale_x = W / w + scale_y = H / h + + # Scale the optical flow by the resizing factors + flow_scaled = flow.clone() + flow_scaled[:, 0] *= scale_x + flow_scaled[:, 1] *= scale_y + + # Resize the optical flow to the new size (H, W) + flow_resized = F.interpolate(flow_scaled, size=(H, W), mode='bilinear', align_corners=False) + + return flow_resized + +def check_consistency(flow1: torch.Tensor, flow2: torch.Tensor) -> torch.Tensor: + """ + Check the consistency of two optical flows. + flow1: (B, 2, H, W) + flow2: (B, 2, H, W) + if want the output to be forward flow, then flow1 is the forward flow and flow2 is the backward flow + return: (H, W) + """ + device = flow1.device + height, width = flow1.shape[2:] + + kernel_x = torch.tensor([[0.5, 0, -0.5]]).unsqueeze(0).unsqueeze(0).to(device) + kernel_y = torch.tensor([[0.5], [0], [-0.5]]).unsqueeze(0).unsqueeze(0).to(device) + grad_x = torch.nn.functional.conv2d(flow1[:, :1], kernel_x, padding=(0, 1)) + grad_y = torch.nn.functional.conv2d(flow1[:, 1:], kernel_y, padding=(1, 0)) + + motion_edge = (grad_x * grad_x + grad_y * grad_y).sum(dim=1).squeeze(0) + + ax, ay = torch.meshgrid(torch.arange(width, device=device), torch.arange(height, device=device), indexing='xy') + bx, by = ax + flow1[:, 0], ay + flow1[:, 1] + + x1, y1 = torch.floor(bx).long(), torch.floor(by).long() + x2, y2 = x1 + 1, y1 + 1 + x1 = torch.clamp(x1, 0, width - 1) + x2 = torch.clamp(x2, 0, width - 1) + y1 = torch.clamp(y1, 0, height - 1) + y2 = torch.clamp(y2, 0, height - 1) + + alpha_x, alpha_y = bx - x1.float(), by - y1.float() + + a = (1.0 - alpha_x) * flow2[:, 0, y1, x1] + alpha_x * flow2[:, 0, y1, x2] + b = (1.0 - alpha_x) * flow2[:, 0, y2, x1] + alpha_x * flow2[:, 0, y2, x2] + u = (1.0 - alpha_y) * a + alpha_y * b + + a = (1.0 - alpha_x) * flow2[:, 1, y1, x1] + alpha_x * flow2[:, 1, y1, x2] + b = (1.0 - alpha_x) * flow2[:, 1, y2, x1] + alpha_x * flow2[:, 1, y2, x2] + v = (1.0 - alpha_y) * a + alpha_y * b + + cx, cy = bx + u, by + v + u2, v2 = flow1[:, 0], flow1[:, 1] + + reliable = ((((cx - ax) ** 2 + (cy - ay) ** 2) < (0.01 * (u2 ** 2 + v2 ** 2 + u ** 2 + v ** 2) + 0.5)) & (motion_edge <= 0.01 * (u2 ** 2 + v2 ** 2) + 0.002)).float() + + return reliable # (B, 1, H, W) + + +class RAFTFlow(torch.nn.Module): + ''' + # Instantiate the RAFTFlow class + raft_flow = RAFTFlow(device='cuda') + + # Load a pair of image frames as PyTorch tensors + img1 = torch.tensor(np.random.rand(3, 720, 1280), dtype=torch.float32) + img2 = torch.tensor(np.random.rand(3, 720, 1280), dtype=torch.float32) + + # Compute optical flow between the two frames + (optional) image_size = (256, 256) or None + flow = raft_flow.compute_flow(img1, img2, image_size) # flow will be computed at the original image size if image_size is None + # this flow can be used to warp the second image to the first image + + # Warp the second image using the flow + warped_img = warp_image(img2, flow) + ''' + def __init__(self, *args): + """ + Args: + device (str): Device to run the model on ("cpu" or "cuda"). + """ + super().__init__(*args) + weights = Raft_Large_Weights.DEFAULT + self.model = raft_large(weights=weights, progress=False) + self.model_transform = weights.transforms() + + def forward(self, img1, img2, img_size=None): + """ + Compute optical flow between two frames using RAFT model. + + Args: + img1 (torch.Tensor): First frame tensor with shape (B, C, H, W). + img2 (torch.Tensor): Second frame tensor with shape (B, C, H, W). + img_size (tuple): Size of the input images to be processed. + + Returns: + flow (torch.Tensor): Optical flow tensor with shape (B, 2, H, W). + """ + original_size = img1.shape[2:] + # Preprocess the input frames + if img_size is not None: + img1 = TF.resize(img1, size=img_size, antialias=False) + img2 = TF.resize(img2, size=img_size, antialias=False) + + img1, img2 = self.model_transform(img1, img2) + + # Compute the optical flow using the RAFT model + with torch.no_grad(): + list_of_flows = self.model(img1, img2) + flow = list_of_flows[-1] + + if img_size is not None: + flow = resize_flow(flow, original_size) + + return flow diff --git a/misc_utils/image_utils.py b/misc_utils/image_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..8b7284f8d2d0baefdfe5166ecd6c0820f935206a --- /dev/null +++ b/misc_utils/image_utils.py @@ -0,0 +1,241 @@ +import os +import matplotlib.pyplot as plt +import torch +import numpy as np +import cv2 +import imageio +from PIL import Image +import textwrap + +def find_nearest_Nx(size, N=32): + return int(np.ceil(size / N) * N) + +def load_image_as_tensor(image_path, image_size): + if isinstance(image_size, int): + image_size = (image_size, image_size) + image = cv2.imread(image_path)[..., ::-1] + try: + image = cv2.resize(image, image_size) + except Exception as e: + print(e) + print(image_path) + + image = torch.from_numpy(np.array(image).transpose(2, 0, 1)) / 255. + return image + +def show_image(image): + if len(image.shape) == 4: + image = image[0] + plt.imshow(image.permute(1, 2, 0).detach().cpu().numpy()) + plt.show() + +def extract_video(video_path, save_dir, sampling_fps, skip_frames=0): + os.makedirs(save_dir, exist_ok=True) + cap = cv2.VideoCapture(video_path) + frame_skip = int(cap.get(cv2.CAP_PROP_FPS) / sampling_fps) + frame_count = 0 + save_count = 0 + while True: + ret, frame = cap.read() + if not ret: + break + if frame_count < skip_frames: # skip the first N frames + frame_count += 1 + continue + if (frame_count - skip_frames) % frame_skip == 0: + # Save the frame as an image file if it doesn't already exist + save_path = os.path.join(save_dir, f"frame{save_count:04d}.jpg") + save_count += 1 + if not os.path.exists(save_path): + cv2.imwrite(save_path, frame) + frame_count += 1 + cap.release() + cv2.destroyAllWindows() + +def concatenate_frames_to_video(frame_dir, video_path, fps): + os.makedirs(os.path.dirname(video_path), exist_ok=True) + # Get the list of frame file names in the directory + frame_files = [f for f in os.listdir(frame_dir) if f.startswith("frame")] + # Sort the frame file names in ascending order + frame_files.sort() + # Load the first frame to get the frame size + frame = cv2.imread(os.path.join(frame_dir, frame_files[0])) + height, width, _ = frame.shape + # Initialize the video writer + fourcc = cv2.VideoWriter_fourcc(*"mp4v") + out = cv2.VideoWriter(video_path, fourcc, fps, (width, height)) + # Loop through the frame files and add them to the video + for frame_file in frame_files: + frame_path = os.path.join(frame_dir, frame_file) + frame = cv2.imread(frame_path) + out.write(frame) + # Release the video writer + out.release() + +def cumulative_histogram(hist): + cum_hist = hist.copy() + for i in range(1, len(hist)): + cum_hist[i] = cum_hist[i - 1] + hist[i] + return cum_hist + +def histogram_matching(src_img, ref_img): + src_img = (src_img * 255).astype(np.uint8) + ref_img = (ref_img * 255).astype(np.uint8) + src_img_yuv = cv2.cvtColor(src_img, cv2.COLOR_RGB2YUV) + ref_img_yuv = cv2.cvtColor(ref_img, cv2.COLOR_RGB2YUV) + + matched_img = np.zeros_like(src_img_yuv) + for channel in range(src_img_yuv.shape[2]): + src_hist, _ = np.histogram(src_img_yuv[:, :, channel].ravel(), 256, (0, 256)) + ref_hist, _ = np.histogram(ref_img_yuv[:, :, channel].ravel(), 256, (0, 256)) + + src_cum_hist = cumulative_histogram(src_hist) + ref_cum_hist = cumulative_histogram(ref_hist) + + lut = np.zeros(256, dtype=np.uint8) + j = 0 + for i in range(256): + while ref_cum_hist[j] < src_cum_hist[i] and j < 255: + j += 1 + lut[i] = j + + matched_img[:, :, channel] = cv2.LUT(src_img_yuv[:, :, channel], lut) + + matched_img = cv2.cvtColor(matched_img, cv2.COLOR_YUV2RGB) + matched_img = matched_img.astype(np.float32) / 255 + return matched_img + +def canny_image_batch(image_batch, low_threshold=100, high_threshold=200): + if isinstance(image_batch, torch.Tensor): + # [-1, 1] tensor -> [0, 255] numpy array + is_torch = True + device = image_batch.device + image_batch = (image_batch + 1) * 127.5 + image_batch = image_batch.permute(0, 2, 3, 1).detach().cpu().numpy() + image_batch = image_batch.astype(np.uint8) + image_batch = np.array([cv2.Canny(image, low_threshold, high_threshold) for image in image_batch]) + image_batch = image_batch[:, :, :, None] + image_batch = np.concatenate([image_batch, image_batch, image_batch], axis=3) + + if is_torch: + # [0, 255] numpy array -> [-1, 1] tensor + image_batch = torch.from_numpy(image_batch).permute(0, 3, 1, 2).float() / 255. + image_batch = image_batch.to(device) + return image_batch + + +def images_to_gif(images, filename, fps): + os.makedirs(os.path.dirname(filename), exist_ok=True) + # Normalize to 0-255 and convert to uint8 + images = [(img * 255).astype(np.uint8) if img.dtype == np.float32 else img for img in images] + images = [Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) for img in images] + imageio.mimsave(filename, images, duration=1 / fps) + +def load_gif(image_path): + import imageio + gif = imageio.get_reader(image_path) + np_images = np.array([frame[..., :3] for frame in gif]) + return np_images + +def add_text_to_frame(frame, text, font_scale=1, thickness=2, color=(0, 0, 0), bg_color=(255, 255, 255), max_width=30): + """ + Add text to a frame. + """ + # Make a copy of the frame + frame_with_text = np.copy(frame) + # Choose font + font = cv2.FONT_HERSHEY_SIMPLEX + # Split text into lines if it's too long + lines = textwrap.wrap(text, width=max_width) + # Get total text height + total_text_height = len(lines) * (thickness * font_scale + 10) + 60 * font_scale + # Create an image filled with the background color, having enough space for the text + text_bg_img = np.full((int(total_text_height), frame.shape[1], 3), bg_color, dtype=np.uint8) + # Put each line on the text background image + y = 0 + for line in lines: + text_size, _ = cv2.getTextSize(line, font, font_scale, thickness) + text_x = (text_bg_img.shape[1] - text_size[0]) // 2 + y += text_size[1] + 10 + cv2.putText(text_bg_img, line, (text_x, y), font, font_scale, color, thickness) + # Append the text background image to the frame + frame_with_text = np.vstack((frame_with_text, text_bg_img)) + + return frame_with_text + +def add_text_to_gif(numpy_images, text, **kwargs): + """ + Add text to each frame of a gif. + """ + # Iterate over frames and add text to each frame + frames_with_text = [] + for frame in numpy_images: + frame_with_text = add_text_to_frame(frame, text, **kwargs) + frames_with_text.append(frame_with_text) + + # Convert the list of frames to a numpy array + numpy_images_with_text = np.array(frames_with_text) + + return numpy_images_with_text + +def pad_images_to_same_height(images): + """ + Pad images to the same height. + """ + # Find the maximum height + max_height = max(img.shape[0] for img in images) + + # Pad each image to the maximum height + padded_images = [] + for img in images: + pad_height = max_height - img.shape[0] + padded_img = cv2.copyMakeBorder(img, 0, pad_height, 0, 0, cv2.BORDER_CONSTANT, value=[255, 255, 255]) + padded_images.append(padded_img) + + return padded_images + +def concatenate_gifs(gifs): + """ + Concatenate gifs. + """ + # Ensure that all gifs have the same number of frames + min_num_frames = min(gif.shape[0] for gif in gifs) + gifs = [gif[:min_num_frames] for gif in gifs] + + # Concatenate each frame + concatenated_gifs = [] + for i in range(min_num_frames): + # Get the i-th frame from each gif + frames = [gif[i] for gif in gifs] + + # Pad the frames to the same height + padded_frames = pad_images_to_same_height(frames) + + # Concatenate the padded frames + concatenated_frame = np.concatenate(padded_frames, axis=1) + + concatenated_gifs.append(concatenated_frame) + + return np.array(concatenated_gifs) + +def stack_gifs(gifs): + '''vertically stack gifs''' + min_num_frames = min(gif.shape[0] for gif in gifs) + stacked_gifs = [] + + for i in range(min_num_frames): + frames = [gif[i] for gif in gifs] + stacked_frame = np.concatenate(frames, axis=0) + stacked_gifs.append(stacked_frame) + + return np.array(stacked_gifs) + +def save_tensor_to_gif(images, filename, fps): + images = images.squeeze(0).detach().cpu().numpy().transpose(0, 2, 3, 1) / 2 + 0.5 + images_to_gif(images, filename, fps) + +def save_tensor_to_images(images, output_dir): + images = images.squeeze(0).detach().cpu().numpy().transpose(0, 2, 3, 1) / 2 + 0.5 + os.makedirs(output_dir, exist_ok=True) + for i in range(images.shape[0]): + plt.imsave(f'{output_dir}/{i:03d}.jpg', images[i]) \ No newline at end of file diff --git a/misc_utils/model_utils.py b/misc_utils/model_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2093575f324765fa45eac7f9bb82e7863c418fba --- /dev/null +++ b/misc_utils/model_utils.py @@ -0,0 +1,115 @@ +import importlib +import torch +import numpy as np +from inspect import isfunction + +def instantiate_from_config(config): + if not "target" in config: + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + +def exists(x): + return x is not None + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + +def right_pad_dims_to(x, t): + padding_dims = x.ndim - t.ndim + if padding_dims <= 0: + return t + return t.view(*t.shape, *((1,) * padding_dims)) + +def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): + if schedule == "linear" or schedule == "scaled_linear": + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + + elif schedule == "cosine": + timesteps = ( + torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s + ) + alphas = timesteps / (1 + cosine_s) * np.pi / 2 + alphas = torch.cos(alphas).pow(2) + alphas = alphas / alphas[0] + betas = 1 - alphas[1:] / alphas[:-1] + betas = np.clip(betas, a_min=0, a_max=0.999) + + elif schedule == "sqrt_linear": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) + elif schedule == "sqrt": + betas = torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64) ** 0.5 + else: + raise ValueError(f"schedule '{schedule}' unknown.") + return betas.numpy() + + +def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True): + if ddim_discr_method == 'uniform': + c = num_ddpm_timesteps // num_ddim_timesteps + ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) + elif ddim_discr_method == 'quad': + ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + else: + raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') + + # assert ddim_timesteps.shape[0] == num_ddim_timesteps + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 + if verbose: + print(f'Selected timesteps for ddim sampler: {steps_out}') + return steps_out + + +def make_ddim_sampling_parameters(alphacums, ddim_timesteps, eta, verbose=True): + # select alphas for computing the variance schedule + alphas = alphacums[ddim_timesteps] + alphas_prev = np.asarray([alphacums[0]] + alphacums[ddim_timesteps[:-1]].tolist()) + + # according the the formula provided in https://arxiv.org/abs/2010.02502 + sigmas = eta * np.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev)) + if verbose: + print(f'Selected alphas for ddim sampler: a_t: {alphas}; a_(t-1): {alphas_prev}') + print(f'For the chosen value of eta, which is {eta}, ' + f'this results in the following sigma_t schedule for ddim sampler {sigmas}') + return sigmas, alphas, alphas_prev + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) \ No newline at end of file diff --git a/misc_utils/ptp_utils.py b/misc_utils/ptp_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6eb0dabd07f4ffb91ed69c5b77d2f453840bc94a --- /dev/null +++ b/misc_utils/ptp_utils.py @@ -0,0 +1,181 @@ +from dataclasses import dataclass +import torch +import numpy as np + +@dataclass +class Edit: + old: str + new: str + weight: float = 1.0 + + +@dataclass +class Insert: + text: str + weight: float = 1.0 + + @property + def old(self): + return "" + + @property + def new(self): + return self.text + + +@dataclass +class Delete: + text: str + weight: float = 1.0 + + @property + def old(self): + return self.text + + @property + def new(self): + return "" + + +@dataclass +class Text: + text: str + weight: float = 1.0 + + @property + def old(self): + return self.text + + @property + def new(self): + return self.text + +@torch.inference_mode() +def get_text_embedding(prompt, tokenizer, text_encoder): + text_input_ids = tokenizer( + prompt, + padding="max_length", + truncation=True, + max_length=tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + text_embeddings = text_encoder(text_input_ids.to(text_encoder.device))[0] + return text_embeddings + +@torch.inference_mode() +def encode_text(text_pieces, tokenizer, text_encoder): + n_old_tokens = 0 + n_new_tokens = 0 + new_id_to_old_id = [] + weights = [] + for piece in text_pieces: + old, new = piece.old, piece.new + old_tokens = tokenizer.tokenize(old) + new_tokens = tokenizer.tokenize(new) + if len(old_tokens) == 0 and len(new_tokens) == 0: + continue + elif old == new: + n_old_tokens += len(old_tokens) + n_new_tokens += len(new_tokens) + new_id_to_old_id.extend(range(n_old_tokens - len(old_tokens), n_old_tokens)) + elif len(old_tokens) == 0: + # insert + new_id_to_old_id.extend([-1] * len(new_tokens)) + n_new_tokens += len(new_tokens) + elif len(new_tokens) == 0: + # delete + n_old_tokens += len(old_tokens) + else: + # replace + n_old_tokens += len(old_tokens) + n_new_tokens += len(new_tokens) + start = n_old_tokens - len(old_tokens) + end = n_old_tokens + ids = np.linspace(start, end, len(new_tokens), endpoint=False).astype(int) + new_id_to_old_id.extend(list(ids)) + weights.extend([piece.weight] * len(new_tokens)) + + old_prompt = " ".join([piece.old for piece in text_pieces]) + new_prompt = " ".join([piece.new for piece in text_pieces]) + old_text_input_ids = tokenizer( + old_prompt, + padding="max_length", + truncation=True, + max_length=tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + new_text_input_ids = tokenizer( + new_prompt, + padding="max_length", + truncation=True, + max_length=tokenizer.model_max_length, + return_tensors="pt", + ).input_ids + + old_text_embeddings = text_encoder(old_text_input_ids.to(text_encoder.device))[0] + new_text_embeddings = text_encoder(new_text_input_ids.to(text_encoder.device))[0] + value = new_text_embeddings.clone() # batch (1), seq, dim + key = new_text_embeddings.clone() + + for i, (j, weight) in enumerate(zip(new_id_to_old_id, weights)): + if 0 <= j < old_text_embeddings.shape[1]: + key[0, i] = old_text_embeddings[0, j] + value[0, i] *= weight + return key, value + +@torch.inference_mode() +def get_text_embedding_openclip(prompt, text_encoder, device='cuda'): + import open_clip + text_input_ids = open_clip.tokenize(prompt) + text_embeddings = text_encoder(text_input_ids.to(device)) + return text_embeddings + +@torch.inference_mode() +def encode_text_openclip(text_pieces, text_encoder, device='cuda'): + import open_clip + n_old_tokens = 0 + n_new_tokens = 0 + new_id_to_old_id = [] + weights = [] + for piece in text_pieces: + old, new = piece.old, piece.new + old_tokens = open_clip.tokenize(old) + new_tokens = open_clip.tokenize(new) + if len(old_tokens) == 0 and len(new_tokens) == 0: + continue + elif old == new: + n_old_tokens += len(old_tokens) + n_new_tokens += len(new_tokens) + new_id_to_old_id.extend(range(n_old_tokens - len(old_tokens), n_old_tokens)) + elif len(old_tokens) == 0: + # insert + new_id_to_old_id.extend([-1] * len(new_tokens)) + n_new_tokens += len(new_tokens) + elif len(new_tokens) == 0: + # delete + n_old_tokens += len(old_tokens) + else: + # replace + n_old_tokens += len(old_tokens) + n_new_tokens += len(new_tokens) + start = n_old_tokens - len(old_tokens) + end = n_old_tokens + ids = np.linspace(start, end, len(new_tokens), endpoint=False).astype(int) + new_id_to_old_id.extend(list(ids)) + weights.extend([piece.weight] * len(new_tokens)) + + old_prompt = " ".join([piece.old for piece in text_pieces]) + new_prompt = " ".join([piece.new for piece in text_pieces]) + old_text_input_ids = open_clip.tokenize(old_prompt) + new_text_input_ids = open_clip.tokenize(new_prompt) + + old_text_embeddings = text_encoder(old_text_input_ids.to(device)) + new_text_embeddings = text_encoder(new_text_input_ids.to(device)) + value = new_text_embeddings.clone() # batch (1), seq, dim + key = new_text_embeddings.clone() + + for i, (j, weight) in enumerate(zip(new_id_to_old_id, weights)): + if 0 <= j < old_text_embeddings.shape[1]: + key[0, i] = old_text_embeddings[0, j] + value[0, i] *= weight + return key, value \ No newline at end of file diff --git a/misc_utils/train_utils.py b/misc_utils/train_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..afbe929d74f43c71090bac9491a08609c9b396e2 --- /dev/null +++ b/misc_utils/train_utils.py @@ -0,0 +1,176 @@ +import torch +from omegaconf import OmegaConf +from pytorch_lightning.loggers import WandbLogger +from misc_utils.model_utils import instantiate_from_config, get_obj_from_str + +from diffusers import AutoencoderKL +import os +import json + +def get_models(args): + unet = instantiate_from_config(args.unet) + model_dict = { + 'unet': unet, + } + + if args.get('vae'): + vae = instantiate_from_config(args.vae) + model_dict['vae'] = vae + + if args.get('text_model'): + text_model = instantiate_from_config(args.text_model) + model_dict['text_model'] = text_model + + if args.get('ctrlnet'): # 这边还可以加ctrlnet... (感觉是哪个地方搬来的代码) + ctrlnet = instantiate_from_config(args.ctrlnet) + model_dict['ctrlnet'] = ctrlnet + + return model_dict + +def get_text_model(args): + base_path = None + train_net = None + if args.get('diffusion'): + if args.diffusion.params.get('base_path'):# 这边有base path的情况下已经load参数了 + base_path = args.diffusion.params.base_path + train_net = args.diffusion.params.get('unet_init_weights') + if args.get('text_model'): + args.text_model.params.base_path = base_path + text_model = instantiate_from_config(args.text_model) + return text_model + return None + +def get_vae(args): + base_path = None + if args.get('diffusion'): + if args.diffusion.params.get('base_path'): + base_path = args.diffusion.params.base_path + vae = AutoencoderKL.from_pretrained(os.path.join(base_path, "vae")) + return vae + return None + +def get_ic_models(args): + unet = instantiate_from_config(args.unet) + model_dict = { + 'unet': unet, + } + + vae = get_vae(args) # 这边vae是直接diffusers中的组件加载的 + if vae: + model_dict['vae'] = vae + + text_model = get_text_model(args) # text model的话整体没咋变, 主要更改了from_pretrained的来源 + if text_model: + model_dict['text_model'] = text_model + + if args.get('ctrlnet'): # 这边还可以加ctrlnet... (感觉是哪个地方搬来的代码) + ctrlnet = instantiate_from_config(args.ctrlnet) + model_dict['ctrlnet'] = ctrlnet + + return model_dict + +# def get_models(args): +# unet = instantiate_from_config(args.unet) +# model_dict = { +# 'unet': unet, +# } + +# if args.get('vae'): +# vae = instantiate_from_config(args.vae) +# model_dict['vae'] = vae + +# if args.get('text_model'): +# text_model = instantiate_from_config(args.text_model) +# model_dict['text_model'] = text_model + +# if args.get('ctrlnet'): # 这边还可以加ctrlnet... (感觉是哪个地方搬来的代码) +# ctrlnet = instantiate_from_config(args.ctrlnet) +# model_dict['ctrlnet'] = ctrlnet + +# return model_dict + +def get_DDPM(diffusion_configs, log_args={}, **models): + diffusion_model_class = diffusion_configs['target'] + diffusion_args = diffusion_configs['params'] + DDPM_model = get_obj_from_str(diffusion_model_class) # pl_trainer.instruct_p2p_video.InstructP2PVideoTrainerTemporal + ddpm_model = DDPM_model( + log_args=log_args, + **models, + **diffusion_args + ) + return ddpm_model + + +def get_logger(args): + wandb_logger = WandbLogger( + project=args["expt_name"], + ) + return wandb_logger + +def get_callbacks(args, wandb_logger): + callbacks = [] + for callback in args['callbacks']: + if callback.get('require_wandb', False): + # we need to pass wandb logger to the callback + callback_obj = get_obj_from_str(callback.target) + callbacks.append( + callback_obj(wandb_logger=wandb_logger, **callback.params) + ) + else: + callbacks.append( + instantiate_from_config(callback) + ) + return callbacks + +def get_dataset(args): + from torch.utils.data import DataLoader + data_args = args['data'] + # import pdb; pdb.set_trace() + train_set = instantiate_from_config(data_args['train']) + val_set = instantiate_from_config(data_args['val']) + # import pdb; pdb.set_trace() + # import pdb; pdb.set_trace() + train_loader = DataLoader( + train_set, batch_size=data_args['batch_size'], shuffle=True, + num_workers=4*len(args['trainer_args']['devices']), pin_memory=True + ) + val_loader = DataLoader( + val_set, batch_size=data_args['val_batch_size'], + num_workers=len(args['trainer_args']['devices']), pin_memory=True + ) # 不shuffle + return train_loader, val_loader, train_set, val_set + +def unit_test_create_model(config_path): + device = 'cuda' if torch.cuda.is_available() else 'cpu' + conf = OmegaConf.load(config_path) + models = get_ic_models(conf) + ddpm = get_DDPM(conf['diffusion'], log_args=conf, **models) + ddpm = ddpm.to(device) + return ddpm + +def unit_test_create_dataset(config_path, split='train'): + device = 'cuda' if torch.cuda.is_available() else 'cpu' + conf = OmegaConf.load(config_path) + train_loader, val_loader, train_set, val_set = get_dataset(conf) + if split == 'train': + batch = next(iter(train_loader)) + else: + batch = next(iter(val_loader)) + for k, v in batch.items(): + if isinstance(v, torch.Tensor): + batch[k] = v.to(device) + return batch + +def unit_test_training_step(config_path): + ddpm = unit_test_create_model(config_path) + batch = unit_test_create_dataset(config_path) + res = ddpm.training_step(batch, 0) + return res + +def unit_test_val_step(config_path): + ddpm = unit_test_create_model(config_path) + batch = unit_test_create_dataset(config_path, split='val') + res = ddpm.validation_step(batch, 0) + return res + +NEGATIVE_PROMPTS = "(((deformed))), blurry, bad anatomy, disfigured, poorly drawn face, mutation, mutated, (extra_limb), (ugly), (poorly drawn hands), fused fingers, messy drawing, broken legs censor, censored, censor_bar, multiple breasts, (mutated hands and fingers:1.5), (long body :1.3), (mutation, poorly drawn :1.2), black-white, bad anatomy, liquid body, liquidtongue, disfigured, malformed, mutated, anatomical nonsense, text font ui, error, malformed hands, long neck, blurred, lowers, low res, bad anatomy, bad proportions, bad shadow, uncoordinated body, unnatural body, fused breasts, bad breasts, huge breasts, poorly drawn breasts, extra breasts, liquid breasts, heavy breasts, missingbreasts, huge haunch, huge thighs, huge calf, bad hands, fused hand, missing hand, disappearing arms, disappearing thigh, disappearing calf, disappearing legs, fusedears, bad ears, poorly drawn ears, extra ears, liquid ears, heavy ears, missing ears, old photo, low res, black and white, black and white filter, colorless" diff --git a/misc_utils/video_ptp_utils.py b/misc_utils/video_ptp_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..40b46f5f6767d057b6b12e757363f769eee9abbe --- /dev/null +++ b/misc_utils/video_ptp_utils.py @@ -0,0 +1,97 @@ +import torch +from modules.damo_text_to_video.unet_sd import UNetSD +from misc_utils.train_utils import instantiate_from_config +from omegaconf import OmegaConf +from modules.damo_text_to_video.text_model import FrozenOpenCLIPEmbedder +from typing import List, Tuple, Union +from diff_match_patch import diff_match_patch +import difflib +from misc_utils.ptp_utils import get_text_embedding_openclip, encode_text_openclip, Text, Edit, Insert, Delete + + +def get_models_of_damo_model( + unet_config: str, + unet_ckpt: str, + vae_config: str, + vae_ckpt: str, + text_model_ckpt: str, +): + vae_conf = OmegaConf.load(vae_config) + unet_conf = OmegaConf.load(unet_config) + + vae = instantiate_from_config(vae_conf.vae) + vae.load_state_dict(torch.load(vae_ckpt, map_location='cpu')) + vae = vae.half().cuda() + + unet = UNetSD(**unet_conf.model.model_cfg) + unet.load_state_dict(torch.load(unet_ckpt, map_location='cpu')) + unet = unet.half().cuda() + + text_model = FrozenOpenCLIPEmbedder(version=text_model_ckpt, layer='penultimate') + text_model = text_model.half().cuda() + + return vae, unet, text_model + +def compute_diff_old(old_sentence: str, new_sentence: str) -> List[Tuple[Union[Text, Edit, Insert, Delete], str, str]]: + dmp = diff_match_patch() + diff = dmp.diff_main(old_sentence, new_sentence) + dmp.diff_cleanupSemantic(diff) + + result = [] + i = 0 + while i < len(diff): + op, data = diff[i] + if op == 0: # Equal + # result.append((Text, data, data)) + result.append(Text(text=data)) + elif op == -1: # Delete + if i + 1 < len(diff) and diff[i + 1][0] == 1: # If next operation is Insert + result.append(Edit(old=data, new=diff[i + 1][1])) # Append as Edit operation + i += 1 # Skip next operation because we've handled it here + else: + result.append(Delete(text=data)) + elif op == 1: # Insert + if i == 0 or diff[i - 1][0] != -1: # If previous operation wasn't Delete + result.append(Insert(text=data)) + i += 1 + + return result + +def compute_diff(old_sentence: str, new_sentence: str) -> List[Union[Text, Edit, Insert, Delete]]: + differ = difflib.Differ() + diff = list(differ.compare(old_sentence.split(), new_sentence.split())) + + result = [] + i = 0 + while i < len(diff): + if diff[i][0] == ' ': # Equal + equal_words = [diff[i][2:]] + while i + 1 < len(diff) and diff[i + 1][0] == ' ': + i += 1 + equal_words.append(diff[i][2:]) + result.append(Text(text=' '.join(equal_words))) + elif diff[i][0] == '-': # Delete + deleted_words = [diff[i][2:]] + while i + 1 < len(diff) and diff[i + 1][0] == '-': + i += 1 + deleted_words.append(diff[i][2:]) + result.append(Delete(text=' '.join(deleted_words))) + elif diff[i][0] == '+': # Insert + inserted_words = [diff[i][2:]] + while i + 1 < len(diff) and diff[i + 1][0] == '+': + i += 1 + inserted_words.append(diff[i][2:]) + result.append(Insert(text=' '.join(inserted_words))) + i += 1 + + # Post-process to merge adjacent inserts and deletes into edits + i = 0 + while i < len(result) - 1: + if isinstance(result[i], Delete) and isinstance(result[i+1], Insert): + result[i:i+2] = [Edit(old=result[i].text, new=result[i+1].text)] + elif isinstance(result[i], Insert) and isinstance(result[i+1], Delete): + result[i:i+2] = [Edit(old=result[i+1].text, new=result[i].text)] + else: + i += 1 + + return result \ No newline at end of file diff --git a/modules/damo_text_to_video/configuration.json b/modules/damo_text_to_video/configuration.json new file mode 100644 index 0000000000000000000000000000000000000000..70f08ae818b74d3435c0adeed2729f39b2ff33b5 --- /dev/null +++ b/modules/damo_text_to_video/configuration.json @@ -0,0 +1,31 @@ +{ "framework": "pytorch", + "task": "text-to-video-synthesis", + "model": { + "type": "latent-text-to-video-synthesis", + "model_args": { + "ckpt_clip": "open_clip_pytorch_model.bin", + "ckpt_unet": "text2video_pytorch_model.pth", + "ckpt_autoencoder": "VQGAN_autoencoder.pth", + "max_frames": 16, + "tiny_gpu": 1 + }, + "model_cfg": { + "in_dim": 4, + "dim": 320, + "y_dim": 768, + "context_dim": 1024, + "out_dim": 4, + "dim_mult": [1, 2, 4, 4], + "num_heads": 8, + "head_dim": 64, + "num_res_blocks": 2, + "attn_scales": [1, 0.5, 0.25], + "dropout": 0.1, + "temporal_attention": "True", + "use_checkpoint": "True" + } + }, + "pipeline": { + "type": "latent-text-to-video-synthesis" + } +} \ No newline at end of file diff --git a/modules/damo_text_to_video/text_model.py b/modules/damo_text_to_video/text_model.py new file mode 100644 index 0000000000000000000000000000000000000000..2be16259e423a0e7e0b6c5b2332a0d3b02006577 --- /dev/null +++ b/modules/damo_text_to_video/text_model.py @@ -0,0 +1,63 @@ +import torch +import open_clip + +class FrozenOpenCLIPEmbedder(torch.nn.Module): + """ + Uses the OpenCLIP transformer encoder for text + """ + LAYERS = ['last', 'penultimate'] + + def __init__(self, + arch='ViT-H-14', + version='open_clip_pytorch_model.bin', + device='cuda', + max_length=77, + freeze=True, + layer='last'): + super().__init__() + assert layer in self.LAYERS + model, _, _ = open_clip.create_model_and_transforms( + arch, device=torch.device('cpu'), pretrained=version) + del model.visual + self.model = model + + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + if self.layer == 'last': + self.layer_idx = 0 + elif self.layer == 'penultimate': + self.layer_idx = 1 + else: + raise NotImplementedError() + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + tokens = open_clip.tokenize(text) + z = self.encode_with_transformer(tokens.to(self.device)) + return z + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_final(x) + return x + + def text_transformer_forward(self, x: torch.Tensor, attn_mask=None): + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - self.layer_idx: + break + x = r(x, attn_mask=attn_mask) + return x + + def encode(self, text): + return self(text) \ No newline at end of file diff --git a/modules/damo_text_to_video/unet_sd.py b/modules/damo_text_to_video/unet_sd.py new file mode 100644 index 0000000000000000000000000000000000000000..a869d4db5c5b3319b54932a4c4c5633eb32b1edc --- /dev/null +++ b/modules/damo_text_to_video/unet_sd.py @@ -0,0 +1,1128 @@ +# Part of the implementation is borrowed and modified from stable-diffusion, +# publicly avaialbe at https://github.com/Stability-AI/stablediffusion. +# Copyright 2021-2022 The Alibaba Fundamental Vision Team Authors. All rights reserved. + +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, repeat + +__all__ = ['UNetSD'] + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if callable(d) else d + + +class UNetSD(nn.Module): + + def __init__(self, + in_dim=7, + dim=512, + y_dim=512, + context_dim=512, + out_dim=6, + dim_mult=[1, 2, 3, 4], + num_heads=None, + head_dim=64, + num_res_blocks=3, + attn_scales=[1 / 2, 1 / 4, 1 / 8], + use_scale_shift_norm=True, + dropout=0.1, + temporal_attn_times=2, + temporal_attention=True, + use_checkpoint=False, + use_image_dataset=False, + use_fps_condition=False, + use_sim_mask=False): + embed_dim = dim * 4 + num_heads = num_heads if num_heads else dim // 32 + super(UNetSD, self).__init__() + self.in_dim = in_dim + self.dim = dim + self.y_dim = y_dim + self.context_dim = context_dim + self.embed_dim = embed_dim + self.out_dim = out_dim + self.dim_mult = dim_mult + self.num_heads = num_heads + # parameters for spatial/temporal attention + self.head_dim = head_dim + self.num_res_blocks = num_res_blocks + self.attn_scales = attn_scales + self.use_scale_shift_norm = use_scale_shift_norm + self.temporal_attn_times = temporal_attn_times + self.temporal_attention = temporal_attention + self.use_checkpoint = use_checkpoint + self.use_image_dataset = use_image_dataset + self.use_fps_condition = use_fps_condition + self.use_sim_mask = use_sim_mask + use_linear_in_temporal = False + transformer_depth = 1 + disabled_sa = False + # params + enc_dims = [dim * u for u in [1] + dim_mult] + dec_dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]] + shortcut_dims = [] + scale = 1.0 + + # embeddings + self.time_embed = nn.Sequential( + nn.Linear(dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + + if self.use_fps_condition: + self.fps_embedding = nn.Sequential( + nn.Linear(dim, embed_dim), nn.SiLU(), + nn.Linear(embed_dim, embed_dim)) + nn.init.zeros_(self.fps_embedding[-1].weight) + nn.init.zeros_(self.fps_embedding[-1].bias) + + # encoder + self.input_blocks = nn.ModuleList() + init_block = nn.ModuleList([nn.Conv2d(self.in_dim, dim, 3, padding=1)]) + + if temporal_attention: + init_block.append( + TemporalTransformer( + dim, + num_heads, + head_dim, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_temporal, + multiply_zero=use_image_dataset)) + + self.input_blocks.append(init_block) + shortcut_dims.append(dim) + for i, (in_dim, + out_dim) in enumerate(zip(enc_dims[:-1], enc_dims[1:])): + for j in range(num_res_blocks): + # residual (+attention) blocks + block = nn.ModuleList([ + ResBlock( + in_dim, + embed_dim, + dropout, + out_channels=out_dim, + use_scale_shift_norm=False, + use_image_dataset=use_image_dataset, + ) + ]) + if scale in attn_scales: + block.append( + SpatialTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=1, + context_dim=self.context_dim, + disable_self_attn=False, + use_linear=True)) + if self.temporal_attention: + block.append( + TemporalTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_temporal, + multiply_zero=use_image_dataset)) + + in_dim = out_dim + self.input_blocks.append(block) + shortcut_dims.append(out_dim) + + # downsample + if i != len(dim_mult) - 1 and j == num_res_blocks - 1: + downsample = Downsample( + out_dim, True, dims=2, out_channels=out_dim) + shortcut_dims.append(out_dim) + scale /= 2.0 + self.input_blocks.append(downsample) + + # middle + self.middle_block = nn.ModuleList([ + ResBlock( + out_dim, + embed_dim, + dropout, + use_scale_shift_norm=False, + use_image_dataset=use_image_dataset, + ), + SpatialTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=1, + context_dim=self.context_dim, + disable_self_attn=False, + use_linear=True) + ]) + + if self.temporal_attention: + self.middle_block.append( + TemporalTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_temporal, + multiply_zero=use_image_dataset, + )) + + self.middle_block.append( + ResBlock( + out_dim, + embed_dim, + dropout, + use_scale_shift_norm=False, + use_image_dataset=use_image_dataset, + )) + + # decoder + self.output_blocks = nn.ModuleList() + for i, (in_dim, + out_dim) in enumerate(zip(dec_dims[:-1], dec_dims[1:])): + for j in range(num_res_blocks + 1): + # residual (+attention) blocks + block = nn.ModuleList([ + ResBlock( + in_dim + shortcut_dims.pop(), + embed_dim, + dropout, + out_dim, + use_scale_shift_norm=False, + use_image_dataset=use_image_dataset, + ) + ]) + if scale in attn_scales: + block.append( + SpatialTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=1, + context_dim=1024, + disable_self_attn=False, + use_linear=True)) + + if self.temporal_attention: + block.append( + TemporalTransformer( + out_dim, + out_dim // head_dim, + head_dim, + depth=transformer_depth, + context_dim=context_dim, + disable_self_attn=disabled_sa, + use_linear=use_linear_in_temporal, + multiply_zero=use_image_dataset)) + in_dim = out_dim + + # upsample + if i != len(dim_mult) - 1 and j == num_res_blocks: + upsample = Upsample( + out_dim, True, dims=2.0, out_channels=out_dim) + scale *= 2.0 + block.append(upsample) + self.output_blocks.append(block) + + # head + self.out = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), + nn.Conv2d(out_dim, self.out_dim, 3, padding=1)) + + # zero out the last layer params + nn.init.zeros_(self.out[-1].weight) + + def forward( + self, + x, + t, + y, + fps=None, + video_mask=None, + focus_present_mask=None, + prob_focus_present=0., + mask_last_frame_num=0 # mask last frame num + ): + """ + prob_focus_present: probability at which a given batch sample will focus on the present + (0. is all off, 1. is completely arrested attention across time) + """ + batch, device = x.shape[0], x.device + self.batch = batch + + # image and video joint training, if mask_last_frame_num is set, prob_focus_present will be ignored + if mask_last_frame_num > 0: + focus_present_mask = None + video_mask[-mask_last_frame_num:] = False + else: + focus_present_mask = default( + focus_present_mask, lambda: prob_mask_like( + (batch, ), prob_focus_present, device=device)) + + time_rel_pos_bias = None + # embeddings + if self.use_fps_condition and fps is not None: + e = self.time_embed(sinusoidal_embedding( + t, self.dim)) + self.fps_embedding( + sinusoidal_embedding(fps, self.dim)) + else: + e = self.time_embed(sinusoidal_embedding(t, self.dim)) + context = y + + # repeat f times for spatial e and context + f = x.shape[2] + e = e.repeat_interleave(repeats=f, dim=0) + if isinstance(context, (tuple, list)): + context = ( + context[0].repeat_interleave(repeats=f, dim=0), + context[1].repeat_interleave(repeats=f, dim=0), + ) + else: + context = context.repeat_interleave(repeats=f, dim=0) + + # always in shape (b f) c h w, except for temporal layer + x = rearrange(x, 'b c f h w -> (b f) c h w') + # encoder + xs = [] + for block in self.input_blocks: + x = self._forward_single(block, x, e, context, time_rel_pos_bias, + focus_present_mask, video_mask) + xs.append(x) + + # middle + for block in self.middle_block: + x = self._forward_single(block, x, e, context, time_rel_pos_bias, + focus_present_mask, video_mask) + + # decoder + for block in self.output_blocks: + x = torch.cat([x, xs.pop()], dim=1) + x = self._forward_single( + block, + x, + e, + context, + time_rel_pos_bias, + focus_present_mask, + video_mask, + reference=xs[-1] if len(xs) > 0 else None) + + # head + x = self.out(x) + # reshape back to (b c f h w) + x = rearrange(x, '(b f) c h w -> b c f h w', b=batch) + return x + + def _forward_single(self, + module, + x, + e, + context, + time_rel_pos_bias, + focus_present_mask, + video_mask, + reference=None): + if isinstance(module, ResidualBlock): + x = x.contiguous() + x = module(x, e, reference) + elif isinstance(module, ResBlock): + x = x.contiguous() + x = module(x, e, self.batch) + elif isinstance(module, SpatialTransformer): + x = module(x, context) + elif isinstance(module, TemporalTransformer): + x = rearrange(x, '(b f) c h w -> b c f h w', b=self.batch) + x = module(x, context) + x = rearrange(x, 'b c f h w -> (b f) c h w') + elif isinstance(module, CrossAttention): + x = module(x, context) + elif isinstance(module, BasicTransformerBlock): + x = module(x, context) + elif isinstance(module, FeedForward): + x = module(x, context) + elif isinstance(module, Upsample): + x = module(x) + elif isinstance(module, Downsample): + x = module(x) + elif isinstance(module, Resample): + x = module(x, reference) + elif isinstance(module, nn.ModuleList): + for block in module: + x = self._forward_single(block, x, e, context, + time_rel_pos_bias, focus_present_mask, + video_mask, reference) + else: + x = module(x) + return x + + +def sinusoidal_embedding(timesteps, dim): + # check input + half = dim // 2 + timesteps = timesteps.float() + # compute sinusoidal embedding + sinusoid = torch.outer( + timesteps, torch.pow(10000, + -torch.arange(half).to(timesteps).div(half))) + x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1) + if dim % 2 != 0: + x = torch.cat([x, torch.zeros_like(x[:, :1])], dim=1) + return x + + +class CrossAttention(nn.Module): + + def __init__(self, + query_dim, + context_dim=None, + heads=8, + dim_head=64, + dropout=0.): + super().__init__() + inner_dim = dim_head * heads + context_dim = default(context_dim, query_dim) + + self.scale = dim_head**-0.5 + self.heads = heads + + self.to_q = nn.Linear(query_dim, inner_dim, bias=False) + self.to_k = nn.Linear(context_dim, inner_dim, bias=False) + self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + + self.to_out = nn.Sequential( + nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)) + + self.ptp_sa_replace = False + self.num_frames = 1 # for ptp sa replacement use + + def forward(self, x, context=None, mask=None): + h = self.heads + + q = self.to_q(x) + + is_self_attn = context is None + context = default(context, x) + if (isinstance(context, list) or isinstance(context, tuple)): + k = self.to_k(context[0]) # use old prompt's new mapping in new prompt for key + v = self.to_v(context[1]) # use new prompt for value + else: + k = self.to_k(context) + v = self.to_v(context) + + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), + (q, k, v)) + sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale + del q, k + + if is_self_attn and self.ptp_sa_replace: #and x.shape[0] < x.shape[1]: + if x.shape[0] < x.shape[1]: + # spatial attention + sim = rearrange(sim, '(b f h) l d -> b f h l d', b=4, f=self.num_frames, h=h) + sims = sim.chunk(4) + sim = torch.cat((sims[0], sims[0], sims[2], sims[2])) + sim = rearrange(sim, 'b f h l d -> (b f h) l d') + else: + # pass + # temporal attention + sim = rearrange(sim, '(b l) f d -> b l f d', b=4) + sims = sim.chunk(4) + sim = torch.cat((sims[0], sims[0], sims[2], sims[2])) + sim = rearrange(sim, 'b l f d -> (b l) f d') + + if exists(mask): + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=h) + sim.masked_fill_(~mask, max_neg_value) + + # attention, what we cannot get enough of + sim = sim.softmax(dim=-1) + + out = torch.einsum('b i j, b j d -> b i d', sim, v) + out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + return self.to_out(out) + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data in spatial axis. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + NEW: use_linear for more efficiency instead of the 1x1 convs + """ + + def __init__(self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0., + context_dim=None, + disable_self_attn=False, + use_linear=False, + use_checkpoint=True): + super().__init__() + if exists(context_dim) and not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + if not use_linear: + self.proj_in = nn.Conv2d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList([ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim[d], + disable_self_attn=disable_self_attn, + checkpoint=use_checkpoint) for d in range(depth) + ]) + if not use_linear: + self.proj_out = zero_module( + nn.Conv2d( + inner_dim, in_channels, kernel_size=1, stride=1, + padding=0)) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if not isinstance(context, list): + context = [context] + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + if not self.use_linear: + x = self.proj_in(x) + x = rearrange(x, 'b c h w -> b (h w) c').contiguous() + if self.use_linear: + x = self.proj_in(x) + for i, block in enumerate(self.transformer_blocks): + x = block(x, context=context[i]) + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w).contiguous() + if not self.use_linear: + x = self.proj_out(x) + return x + x_in + + +class TemporalTransformer(nn.Module): + """ + Transformer block for image-like data in temporal axis. + First, reshape to b, t, d. + Then apply standard transformer action. + Finally, reshape to image + """ + + def __init__(self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0., + context_dim=None, + disable_self_attn=False, + use_linear=False, + use_checkpoint=True, + only_self_att=True, + multiply_zero=False): + super().__init__() + self.multiply_zero = multiply_zero + self.only_self_att = only_self_att + if self.only_self_att: + context_dim = None + if not isinstance(context_dim, list): + context_dim = [context_dim] + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + if not use_linear: + self.proj_in = nn.Conv1d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + else: + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList([ + BasicTransformerBlock( + inner_dim, + n_heads, + d_head, + dropout=dropout, + context_dim=context_dim[d], + checkpoint=use_checkpoint) for d in range(depth) + ]) + if not use_linear: + self.proj_out = zero_module( + nn.Conv1d( + inner_dim, in_channels, kernel_size=1, stride=1, + padding=0)) + else: + self.proj_out = zero_module(nn.Linear(in_channels, inner_dim)) + self.use_linear = use_linear + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + if self.only_self_att: + context = None + if not isinstance(context, list): + context = [context] + b, c, f, h, w = x.shape + x_in = x + x = self.norm(x) + + if not self.use_linear: + x = rearrange(x, 'b c f h w -> (b h w) c f').contiguous() + x = self.proj_in(x) + if self.use_linear: + x = rearrange( + x, '(b f) c h w -> b (h w) f c', f=self.frames).contiguous() + x = self.proj_in(x) + + if self.only_self_att: + x = rearrange(x, 'bhw c f -> bhw f c').contiguous() + for i, block in enumerate(self.transformer_blocks): + x = block(x) + x = rearrange(x, '(b hw) f c -> b hw f c', b=b).contiguous() + else: + x = rearrange(x, '(b hw) c f -> b hw f c', b=b).contiguous() + for i, block in enumerate(self.transformer_blocks): + context[i] = rearrange( + context[i], '(b f) l con -> b f l con', + f=self.frames).contiguous() + # calculate each batch one by one (since number in shape could not greater then 65,535 for some package) + for j in range(b): + context_i_j = repeat( + context[i][j], + 'f l con -> (f r) l con', + r=(h * w) // self.frames, + f=self.frames).contiguous() + x[j] = block(x[j], context=context_i_j) + + if self.use_linear: + x = self.proj_out(x) + x = rearrange(x, 'b (h w) f c -> b f c h w', h=h, w=w).contiguous() + if not self.use_linear: + x = rearrange(x, 'b hw f c -> (b hw) c f').contiguous() + x = self.proj_out(x) + x = rearrange( + x, '(b h w) c f -> b c f h w', b=b, h=h, w=w).contiguous() + + if self.multiply_zero: + x = 0.0 * x + x_in + else: + x = x + x_in + return x + + +class BasicTransformerBlock(nn.Module): + + def __init__(self, + dim, + n_heads, + d_head, + dropout=0., + context_dim=None, + gated_ff=True, + checkpoint=True, + disable_self_attn=False): + super().__init__() + attn_cls = CrossAttention + self.disable_self_attn = disable_self_attn + self.attn1 = attn_cls( + query_dim=dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + context_dim=context_dim if self.disable_self_attn else + None) # is a self-attention if not self.disable_self_attn + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = attn_cls( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + x = self.attn1( + self.norm1(x), + context=context if self.disable_self_attn else None) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +# feedforward +class GEGLU(nn.Module): + + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +class FeedForward(nn.Module): + + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.): + super().__init__() + inner_dim = int(dim * mult) + dim_out = default(dim_out, dim) + project_in = nn.Sequential(nn.Linear( + dim, inner_dim), nn.GELU()) if not glu else GEGLU(dim, inner_dim) + + self.net = nn.Sequential(project_in, nn.Dropout(dropout), + nn.Linear(inner_dim, dim_out)) + + def forward(self, x): + return self.net(x) + + +class Upsample(nn.Module): + """ + An upsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + upsampling occurs in the inner-two dimensions. + """ + + def __init__(self, + channels, + use_conv, + dims=2, + out_channels=None, + padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + if use_conv: + self.conv = nn.Conv2d( + self.channels, self.out_channels, 3, padding=padding) + + def forward(self, x): + assert x.shape[1] == self.channels + if self.dims == 3: + x = F.interpolate( + x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), + mode='nearest') + else: + x = F.interpolate(x, scale_factor=2, mode='nearest') + if self.use_conv: + x = self.conv(x) + return x + + +class ResBlock(nn.Module): + """ + A residual block that can optionally change the number of channels. + :param channels: the number of input channels. + :param emb_channels: the number of timestep embedding channels. + :param dropout: the rate of dropout. + :param out_channels: if specified, the number of out channels. + :param use_conv: if True and out_channels is specified, use a spatial + convolution instead of a smaller 1x1 convolution to change the + channels in the skip connection. + :param dims: determines if the signal is 1D, 2D, or 3D. + :param up: if True, use this block for upsampling. + :param down: if True, use this block for downsampling. + :param use_temporal_conv: if True, use the temporal convolution. + :param use_image_dataset: if True, the temporal parameters will not be optimized. + """ + + def __init__( + self, + channels, + emb_channels, + dropout, + out_channels=None, + use_conv=False, + use_scale_shift_norm=False, + dims=2, + up=False, + down=False, + use_temporal_conv=True, + use_image_dataset=False, + ): + super().__init__() + self.channels = channels + self.emb_channels = emb_channels + self.dropout = dropout + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_scale_shift_norm = use_scale_shift_norm + self.use_temporal_conv = use_temporal_conv + + self.in_layers = nn.Sequential( + nn.GroupNorm(32, channels), + nn.SiLU(), + nn.Conv2d(channels, self.out_channels, 3, padding=1), + ) + + self.updown = up or down + + if up: + self.h_upd = Upsample(channels, False, dims) + self.x_upd = Upsample(channels, False, dims) + elif down: + self.h_upd = Downsample(channels, False, dims) + self.x_upd = Downsample(channels, False, dims) + else: + self.h_upd = self.x_upd = nn.Identity() + + self.emb_layers = nn.Sequential( + nn.SiLU(), + nn.Linear( + emb_channels, + 2 * self.out_channels + if use_scale_shift_norm else self.out_channels, + ), + ) + self.out_layers = nn.Sequential( + nn.GroupNorm(32, self.out_channels), + nn.SiLU(), + nn.Dropout(p=dropout), + zero_module( + nn.Conv2d(self.out_channels, self.out_channels, 3, padding=1)), + ) + + if self.out_channels == channels: + self.skip_connection = nn.Identity() + elif use_conv: + self.skip_connection = conv_nd( + dims, channels, self.out_channels, 3, padding=1) + else: + self.skip_connection = nn.Conv2d(channels, self.out_channels, 1) + + if self.use_temporal_conv: + self.temopral_conv = TemporalConvBlock_v2( + self.out_channels, + self.out_channels, + dropout=0.1, + use_image_dataset=use_image_dataset) + + def forward(self, x, emb, batch_size): + """ + Apply the block to a Tensor, conditioned on a timestep embedding. + :param x: an [N x C x ...] Tensor of features. + :param emb: an [N x emb_channels] Tensor of timestep embeddings. + :return: an [N x C x ...] Tensor of outputs. + """ + return self._forward(x, emb, batch_size) + + def _forward(self, x, emb, batch_size): + if self.updown: + in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1] + h = in_rest(x) + h = self.h_upd(h) + x = self.x_upd(x) + h = in_conv(h) + else: + h = self.in_layers(x) + emb_out = self.emb_layers(emb).type(h.dtype) + while len(emb_out.shape) < len(h.shape): + emb_out = emb_out[..., None] + if self.use_scale_shift_norm: + out_norm, out_rest = self.out_layers[0], self.out_layers[1:] + scale, shift = torch.chunk(emb_out, 2, dim=1) + h = out_norm(h) * (1 + scale) + shift + h = out_rest(h) + else: + h = h + emb_out + h = self.out_layers(h) + h = self.skip_connection(x) + h + + if self.use_temporal_conv: + h = rearrange(h, '(b f) c h w -> b c f h w', b=batch_size) + h = self.temopral_conv(h) + h = rearrange(h, 'b c f h w -> (b f) c h w') + return h + + +class Downsample(nn.Module): + """ + A downsampling layer with an optional convolution. + :param channels: channels in the inputs and outputs. + :param use_conv: a bool determining if a convolution is applied. + :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then + downsampling occurs in the inner-two dimensions. + """ + + def __init__(self, + channels, + use_conv, + dims=2, + out_channels=None, + padding=1): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.dims = dims + stride = 2 if dims != 3 else (1, 2, 2) + if self.use_conv: + self.op = nn.Conv2d( + self.channels, + self.out_channels, + 3, + stride=stride, + padding=padding) + else: + assert self.channels == self.out_channels + self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride) + + def forward(self, x): + assert x.shape[1] == self.channels + return self.op(x) + + +class Resample(nn.Module): + + def __init__(self, in_dim, out_dim, mode): + assert mode in ['none', 'upsample', 'downsample'] + super(Resample, self).__init__() + self.in_dim = in_dim + self.out_dim = out_dim + self.mode = mode + + def forward(self, x, reference=None): + if self.mode == 'upsample': + assert reference is not None + x = F.interpolate(x, size=reference.shape[-2:], mode='nearest') + elif self.mode == 'downsample': + x = F.adaptive_avg_pool2d( + x, output_size=tuple(u // 2 for u in x.shape[-2:])) + return x + + +class ResidualBlock(nn.Module): + + def __init__(self, + in_dim, + embed_dim, + out_dim, + use_scale_shift_norm=True, + mode='none', + dropout=0.0): + super(ResidualBlock, self).__init__() + self.in_dim = in_dim + self.embed_dim = embed_dim + self.out_dim = out_dim + self.use_scale_shift_norm = use_scale_shift_norm + self.mode = mode + + # layers + self.layer1 = nn.Sequential( + nn.GroupNorm(32, in_dim), nn.SiLU(), + nn.Conv2d(in_dim, out_dim, 3, padding=1)) + self.resample = Resample(in_dim, in_dim, mode) + self.embedding = nn.Sequential( + nn.SiLU(), + nn.Linear(embed_dim, + out_dim * 2 if use_scale_shift_norm else out_dim)) + self.layer2 = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), + nn.Conv2d(out_dim, out_dim, 3, padding=1)) + self.shortcut = nn.Identity() if in_dim == out_dim else nn.Conv2d( + in_dim, out_dim, 1) + # zero out the last layer params + nn.init.zeros_(self.layer2[-1].weight) + + def forward(self, x, e, reference=None): + identity = self.resample(x, reference) + x = self.layer1[-1](self.resample(self.layer1[:-1](x), reference)) + e = self.embedding(e).unsqueeze(-1).unsqueeze(-1).type(x.dtype) + if self.use_scale_shift_norm: + scale, shift = e.chunk(2, dim=1) + x = self.layer2[0](x) * (1 + scale) + shift + x = self.layer2[1:](x) + else: + x = x + e + x = self.layer2(x) + x = x + self.shortcut(identity) + return x + + +class AttentionBlock(nn.Module): + + def __init__(self, dim, context_dim=None, num_heads=None, head_dim=None): + # consider head_dim first, then num_heads + num_heads = dim // head_dim if head_dim else num_heads + head_dim = dim // num_heads + assert num_heads * head_dim == dim + super(AttentionBlock, self).__init__() + self.dim = dim + self.context_dim = context_dim + self.num_heads = num_heads + self.head_dim = head_dim + self.scale = math.pow(head_dim, -0.25) + + # layers + self.norm = nn.GroupNorm(32, dim) + self.to_qkv = nn.Conv2d(dim, dim * 3, 1) + if context_dim is not None: + self.context_kv = nn.Linear(context_dim, dim * 2) + self.proj = nn.Conv2d(dim, dim, 1) + + # zero out the last layer params + nn.init.zeros_(self.proj.weight) + + def forward(self, x, context=None): + r"""x: [B, C, H, W]. + context: [B, L, C] or None. + """ + identity = x + b, c, h, w, n, d = *x.size(), self.num_heads, self.head_dim + + # compute query, key, value + x = self.norm(x) + q, k, v = self.to_qkv(x).view(b, n * 3, d, h * w).chunk(3, dim=1) + if context is not None: + ck, cv = self.context_kv(context).reshape(b, -1, n * 2, + d).permute(0, 2, 3, + 1).chunk( + 2, dim=1) + k = torch.cat([ck, k], dim=-1) + v = torch.cat([cv, v], dim=-1) + + # compute attention + attn = torch.matmul(q.transpose(-1, -2) * self.scale, k * self.scale) + attn = F.softmax(attn, dim=-1) + + # gather context + x = torch.matmul(v, attn.transpose(-1, -2)) + x = x.reshape(b, c, h, w) + # output + x = self.proj(x) + return x + identity + + +class TemporalConvBlock_v2(nn.Module): + + def __init__(self, + in_dim, + out_dim=None, + dropout=0.0, + use_image_dataset=False): + super(TemporalConvBlock_v2, self).__init__() + if out_dim is None: + out_dim = in_dim # int(1.5*in_dim) + self.in_dim = in_dim + self.out_dim = out_dim + self.use_image_dataset = use_image_dataset + + # conv layers + self.conv1 = nn.Sequential( + nn.GroupNorm(32, in_dim), nn.SiLU(), + nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))) + self.conv2 = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0))) + self.conv3 = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0))) + self.conv4 = nn.Sequential( + nn.GroupNorm(32, out_dim), nn.SiLU(), nn.Dropout(dropout), + nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0))) + + # zero out the last layer params,so the conv block is identity + nn.init.zeros_(self.conv4[-1].weight) + nn.init.zeros_(self.conv4[-1].bias) + + def forward(self, x): + identity = x + x = self.conv1(x) + x = self.conv2(x) + x = self.conv3(x) + x = self.conv4(x) + + if self.use_image_dataset: + x = identity + 0.0 * x + else: + x = identity + x + return x + + +def prob_mask_like(shape, prob, device): + if prob == 1: + return torch.ones(shape, device=device, dtype=torch.bool) + elif prob == 0: + return torch.zeros(shape, device=device, dtype=torch.bool) + else: + mask = torch.zeros(shape, device=device).float().uniform_(0, 1) < prob + # aviod mask all, which will cause find_unused_parameters error + if mask.all(): + mask[0] = False + return mask + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f'unsupported dimensions: {dims}') + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f'unsupported dimensions: {dims}') \ No newline at end of file diff --git a/modules/kl_autoencoder/__pycache__/autoencoder.cpython-310.pyc b/modules/kl_autoencoder/__pycache__/autoencoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96d09e63c31f6d00abe0b2d46aaa940dd1174da4 Binary files /dev/null and b/modules/kl_autoencoder/__pycache__/autoencoder.cpython-310.pyc differ diff --git a/modules/kl_autoencoder/__pycache__/autoencoder.cpython-38.pyc b/modules/kl_autoencoder/__pycache__/autoencoder.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8608309e117c6c13fcf2a365842fb0d11de3d996 Binary files /dev/null and b/modules/kl_autoencoder/__pycache__/autoencoder.cpython-38.pyc differ diff --git a/modules/kl_autoencoder/autoencoder.py b/modules/kl_autoencoder/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..059fd22f696744b1531129e8ac1480800521642f --- /dev/null +++ b/modules/kl_autoencoder/autoencoder.py @@ -0,0 +1,190 @@ +import torch +import pytorch_lightning as pl +import torch.nn.functional as F +import numpy as np + +from modules.vqvae.model import Encoder, Decoder + +from misc_utils.model_utils import instantiate_from_config + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device) + + def sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.]) + else: + if other is None: + return 0.5 * torch.sum(torch.pow(self.mean, 2) + + self.var - 1.0 - self.logvar, + dim=[1, 2, 3]) + else: + return 0.5 * torch.sum( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var - 1.0 - self.logvar + other.logvar, + dim=[1, 2, 3]) + + def nll(self, sample, dims=[1,2,3]): + if self.deterministic: + return torch.Tensor([0.]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims) + + def mode(self): + return self.mean + +class AutoencoderKL(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + ): + super().__init__() + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + assert ddconfig["double_z"] + self.quant_conv = torch.nn.Conv2d(2*ddconfig["z_channels"], 2*embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + self.embed_dim = embed_dim + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + self.load_state_dict(sd, strict=False) + print(f"Restored from {path}") + + def encode(self, x): + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + # TODO check if need to put sample into DDIM_ldm class + enc = posterior.sample() + return enc #posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + return dec + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + dec = self.decode(z) + return dec, posterior + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + + if optimizer_idx == 0: + # train encoder+decoder+logvar + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log("aeloss", aeloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return aeloss + + if optimizer_idx == 1: + # train the discriminator + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + + self.log("discloss", discloss, prog_bar=True, logger=True, on_step=True, on_epoch=True) + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=False) + return discloss + + def validation_step(self, batch, batch_idx): + inputs = self.get_input(batch, self.image_key) + reconstructions, posterior = self(inputs) + aeloss, log_dict_ae = self.loss(inputs, reconstructions, posterior, 0, self.global_step, + last_layer=self.get_last_layer(), split="val") + + discloss, log_dict_disc = self.loss(inputs, reconstructions, posterior, 1, self.global_step, + last_layer=self.get_last_layer(), split="val") + + self.log("val/rec_loss", log_dict_ae["val/rec_loss"]) + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr = self.learning_rate + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr, betas=(0.5, 0.9)) + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + @torch.no_grad() + def log_images(self, batch, only_inputs=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if not only_inputs: + xrec, posterior = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["samples"] = self.decode(torch.randn_like(posterior.sample())) + log["reconstructions"] = xrec + log["inputs"] = x + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x \ No newline at end of file diff --git a/modules/openclip/__pycache__/modules.cpython-310.pyc b/modules/openclip/__pycache__/modules.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..23c7741d5bbe34cc6b78309169b1334c7890f984 Binary files /dev/null and b/modules/openclip/__pycache__/modules.cpython-310.pyc differ diff --git a/modules/openclip/__pycache__/modules.cpython-38.pyc b/modules/openclip/__pycache__/modules.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4180f8f98f5e61ea736886c38dfee3990924c6fd Binary files /dev/null and b/modules/openclip/__pycache__/modules.cpython-38.pyc differ diff --git a/modules/openclip/modules.py b/modules/openclip/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..34414a78487df797a71d60fbd56ca7a196a1b51a --- /dev/null +++ b/modules/openclip/modules.py @@ -0,0 +1,225 @@ +from typing import Any, Mapping +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint + +from transformers import T5Tokenizer, T5EncoderModel, CLIPTokenizer, CLIPTextModel + +import open_clip + +import os +import json + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class IdentityEncoder(AbstractEncoder): + + def encode(self, x): + return x + + +class ClassEmbedder(nn.Module): + def __init__(self, embed_dim, n_classes=1000, key='class', ucg_rate=0.1): + super().__init__() + self.key = key + self.embedding = nn.Embedding(n_classes, embed_dim) + self.n_classes = n_classes + self.ucg_rate = ucg_rate + + def forward(self, batch, key=None, disable_dropout=False): + if key is None: + key = self.key + # this is for use in crossattn + c = batch[key][:, None] + if self.ucg_rate > 0. and not disable_dropout: + mask = 1. - torch.bernoulli(torch.ones_like(c) * self.ucg_rate) + c = mask * c + (1-mask) * torch.ones_like(c)*(self.n_classes-1) + c = c.long() + c = self.embedding(c) + return c + + def get_unconditional_conditioning(self, bs, device="cuda"): + uc_class = self.n_classes - 1 # 1000 classes --> 0 ... 999, one extra class for ucg (class 1000) + uc = torch.ones((bs,), device=device) * uc_class + uc = {self.key: uc} + return uc + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class FrozenT5Embedder(AbstractEncoder): + """Uses the T5 transformer encoder for text""" + def __init__(self, version="google/t5-v1_1-large", device="cuda", max_length=77, freeze=True): # others are google/t5-v1_1-xl and google/t5-v1_1-xxl + super().__init__() + self.tokenizer = T5Tokenizer.from_pretrained(version) + self.transformer = T5EncoderModel.from_pretrained(version) + self.device = device + self.max_length = max_length # TODO: typical value? + if freeze: + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + #self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from huggingface)""" + LAYERS = [ + "last", + "pooled", + "hidden" + ] + def __init__(self, version="openai/clip-vit-large-patch14", device="cuda", max_length=77, + freeze=True, layer="last", layer_idx=None, base_path=None, inference=False): # clip-vit-base-patch32 + super().__init__() + assert layer in self.LAYERS + + if base_path: + self.tokenizer = CLIPTokenizer.from_pretrained(os.path.join(base_path, 'tokenizer')) + self.transformer = CLIPTextModel.from_pretrained(os.path.join(base_path, 'text_encoder')) + else: + self.tokenizer = CLIPTokenizer.from_pretrained(version) + self.transformer = CLIPTextModel.from_pretrained(version) + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + self.layer_idx = layer_idx + if layer == "hidden": + assert layer_idx is not None + assert 0 <= abs(layer_idx) <= 12 + + def freeze(self): + self.transformer = self.transformer.eval() + #self.train = disabled_train + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens, output_hidden_states=self.layer=="hidden") + if self.layer == "last": + z = outputs.last_hidden_state + elif self.layer == "pooled": + z = outputs.pooler_output[:, None, :] + else: + z = outputs.hidden_states[self.layer_idx] + return z + + def encode(self, text): + return self(text) + + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): + if "transformer.text_model.embeddings.position_ids" in state_dict: + state_dict.pop("transformer.text_model.embeddings.position_ids") # it seems that this is removed from the model in recent transformers versions + return super().load_state_dict(state_dict, strict) + + +class FrozenOpenCLIPEmbedder(AbstractEncoder): + """ + Uses the OpenCLIP transformer encoder for text + """ + LAYERS = [ + #"pooled", + "last", + "penultimate" + ] + def __init__(self, arch="ViT-H-14", version="laion2b_s32b_b79k", device="cuda", max_length=77, + freeze=True, layer="last"): + super().__init__() + assert layer in self.LAYERS + model, _, _ = open_clip.create_model_and_transforms(arch, device=torch.device('cpu'), pretrained=version) + del model.visual + self.model = model + + self.device = device + self.max_length = max_length + if freeze: + self.freeze() + self.layer = layer + if self.layer == "last": + self.layer_idx = 0 + elif self.layer == "penultimate": + self.layer_idx = 1 + else: + raise NotImplementedError() + + def freeze(self): + self.model = self.model.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + tokens = open_clip.tokenize(text) + z = self.encode_with_transformer(tokens.to(self.device)) + return z + + def encode_with_transformer(self, text): + x = self.model.token_embedding(text) # [batch_size, n_ctx, d_model] + x = x + self.model.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_transformer_forward(x, attn_mask=self.model.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.model.ln_final(x) + return x + + def text_transformer_forward(self, x: torch.Tensor, attn_mask = None): + for i, r in enumerate(self.model.transformer.resblocks): + if i == len(self.model.transformer.resblocks) - self.layer_idx: + break + if self.model.transformer.grad_checkpointing and not torch.jit.is_scripting(): + x = checkpoint(r, x, attn_mask) + else: + x = r(x, attn_mask=attn_mask) + return x + + def encode(self, text): + return self(text) + + +class FrozenCLIPT5Encoder(AbstractEncoder): + def __init__(self, clip_version="openai/clip-vit-large-patch14", t5_version="google/t5-v1_1-xl", device="cuda", + clip_max_length=77, t5_max_length=77): + super().__init__() + self.clip_encoder = FrozenCLIPEmbedder(clip_version, device, max_length=clip_max_length) + self.t5_encoder = FrozenT5Embedder(t5_version, device, max_length=t5_max_length) + # print(f"{self.clip_encoder.__class__.__name__} has {count_params(self.clip_encoder)*1.e-6:.2f} M parameters, " + # f"{self.t5_encoder.__class__.__name__} comes with {count_params(self.t5_encoder)*1.e-6:.2f} M params.") + + def encode(self, text): + return self(text) + + def forward(self, text): + clip_z = self.clip_encoder.encode(text) + t5_z = self.t5_encoder.encode(text) + return [clip_z, t5_z] + + diff --git a/modules/video_unet_temporal/__pycache__/attention.cpython-310.pyc b/modules/video_unet_temporal/__pycache__/attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3583bff2ab2ecc8ed89749490a70cd66bfc6dac Binary files /dev/null and b/modules/video_unet_temporal/__pycache__/attention.cpython-310.pyc differ diff --git a/modules/video_unet_temporal/__pycache__/attention.cpython-38.pyc b/modules/video_unet_temporal/__pycache__/attention.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2f2e0d612535672d4f9ff1e20d46ab49173763f Binary files /dev/null and b/modules/video_unet_temporal/__pycache__/attention.cpython-38.pyc differ diff --git a/modules/video_unet_temporal/__pycache__/motion_module.cpython-310.pyc b/modules/video_unet_temporal/__pycache__/motion_module.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..83141aa63d373ea1628a2144251eb818b8a8e62b Binary files /dev/null and b/modules/video_unet_temporal/__pycache__/motion_module.cpython-310.pyc differ diff --git a/modules/video_unet_temporal/__pycache__/motion_module.cpython-38.pyc b/modules/video_unet_temporal/__pycache__/motion_module.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6babec95609d764c60e9a5693467d54b60a795a Binary files /dev/null and b/modules/video_unet_temporal/__pycache__/motion_module.cpython-38.pyc differ diff --git a/modules/video_unet_temporal/__pycache__/resnet.cpython-310.pyc b/modules/video_unet_temporal/__pycache__/resnet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8342b2c86aca001035af1c9f84a3a9530d23f55 Binary files /dev/null and b/modules/video_unet_temporal/__pycache__/resnet.cpython-310.pyc differ diff --git a/modules/video_unet_temporal/__pycache__/resnet.cpython-38.pyc b/modules/video_unet_temporal/__pycache__/resnet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bef8f4635018a159f81d3a76f0e4b73bb8e2a7ca Binary files /dev/null and b/modules/video_unet_temporal/__pycache__/resnet.cpython-38.pyc differ diff --git a/modules/video_unet_temporal/__pycache__/unet.cpython-310.pyc b/modules/video_unet_temporal/__pycache__/unet.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c852c6671bce99af818dc8268234eaab7476d0d Binary files /dev/null and b/modules/video_unet_temporal/__pycache__/unet.cpython-310.pyc differ diff --git a/modules/video_unet_temporal/__pycache__/unet.cpython-38.pyc b/modules/video_unet_temporal/__pycache__/unet.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c3f5511acd6c98560b747d392e1051dde5d1e29 Binary files /dev/null and b/modules/video_unet_temporal/__pycache__/unet.cpython-38.pyc differ diff --git a/modules/video_unet_temporal/__pycache__/unet_blocks.cpython-310.pyc b/modules/video_unet_temporal/__pycache__/unet_blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b82f4e36dd7a286f7ad0282174d413405cf0cd5 Binary files /dev/null and b/modules/video_unet_temporal/__pycache__/unet_blocks.cpython-310.pyc differ diff --git a/modules/video_unet_temporal/__pycache__/unet_blocks.cpython-38.pyc b/modules/video_unet_temporal/__pycache__/unet_blocks.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52ecf124672d5113f1808d471c8b406c236cf18d Binary files /dev/null and b/modules/video_unet_temporal/__pycache__/unet_blocks.cpython-38.pyc differ diff --git a/modules/video_unet_temporal/attention.py b/modules/video_unet_temporal/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..ab62056fcb8c053d23bbf02978a4e822e757d6f8 --- /dev/null +++ b/modules/video_unet_temporal/attention.py @@ -0,0 +1,286 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py + +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn +from torch import einsum +from misc_utils.model_utils import default, exists + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import BaseOutput +from diffusers.utils.import_utils import is_xformers_available +from diffusers.models.attention import Attention, FeedForward, AdaLayerNorm + +from einops import rearrange, repeat + + +@dataclass +class Transformer3DModelOutput(BaseOutput): + sample: torch.FloatTensor + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +class Transformer3DModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # Define input layers + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + + # Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + if use_linear_projection: + self.proj_out = nn.Linear(in_channels, inner_dim) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): + # Input + assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + #! change下一行, 和hd的rmap concat + hdr_latents = None + if isinstance(encoder_hidden_states, dict): + # print("encoder_hidden_states is a dictionary.") + if 'hdr_latents' in encoder_hidden_states: + hdr_latents = encoder_hidden_states['hdr_latents'] + hdr_latents = rearrange(hdr_latents, 'b f n c -> (b f) n c') + if 'encoder_hidden_states' in encoder_hidden_states: + encoder_hidden_states = encoder_hidden_states['encoder_hidden_states'] + # import pdb; pdb.set_trace() + + encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length) #! original (16,77,768) + # import pdb; pdb.set_trace() + if hdr_latents is not None: + encoder_hidden_states = torch.cat((encoder_hidden_states, hdr_latents), dim=1) #! change -> (16, 80, 768) + + # import pdb; pdb.set_trace() + batch, channel, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) #(b*f, h*w, c) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = self.proj_in(hidden_states) + + # Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + video_length=video_length + ) + + # Output + if not self.use_linear_projection: + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + )#(b f) c h w + + output = hidden_states + residual + + output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) + if not return_dict: + return (output,) + + return Transformer3DModelOutput(sample=output) + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + self.use_ada_layer_norm = num_embeds_ada_norm is not None + + # SC-Attn + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + # Cross-Attn + if cross_attention_dim is not None: + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + else: + self.attn2 = None + + if cross_attention_dim is not None: + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + else: + self.norm2 = None + + # Feed-forward + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) + # import pdb; pdb.set_trace() + self.norm3 = nn.LayerNorm(dim) + + # Temp-Attn + # self.attn_temp = Attention( + # query_dim=dim, + # heads=num_attention_heads, + # dim_head=attention_head_dim, + # dropout=dropout, + # bias=attention_bias, + # upcast_attention=upcast_attention, + # ) + # nn.init.zeros_(self.attn_temp.to_out[0].weight.data) + # self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, attention_op): + if not is_xformers_available(): + print("Here is how to install it") + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" + " available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + self.attn1.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers, attention_op) + if self.attn2 is not None: + self.attn2.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers, attention_op) + # self.attn_temp.set_use_memory_efficient_attention_xformers(use_memory_efficient_attention_xformers, attention_op) + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None): + # SparseCausal-Attention + norm_hidden_states = ( + self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) + ) + + if self.only_cross_attention: + hidden_states = ( + self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states + ) + else: + hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states + + if self.attn2 is not None: + # Cross-Attention + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + hidden_states = ( + self.attn2( + norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + ) + + hidden_states + ) + + # Feed-forward + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + + # Temporal-Attention + # d = hidden_states.shape[1] + # hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) + # norm_hidden_states = ( + # self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states) + # ) + # hidden_states = self.attn_temp(norm_hidden_states) + hidden_states + # hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) + + return hidden_states diff --git a/modules/video_unet_temporal/motion_module.py b/modules/video_unet_temporal/motion_module.py new file mode 100644 index 0000000000000000000000000000000000000000..53267436b6583e9e213f937285faf0dcbfea295d --- /dev/null +++ b/modules/video_unet_temporal/motion_module.py @@ -0,0 +1,352 @@ +# from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import numpy as np +import torch.nn.functional as F +from torch import nn + +from diffusers.utils.import_utils import is_xformers_available +from diffusers.models.attention import Attention, FeedForward + +from einops import rearrange, repeat +import math + + +def zero_module(module): + # Zero out the parameters of a module and return it. + for p in module.parameters(): + p.detach().zero_() + return module + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +def get_motion_module( + in_channels, + motion_module_type: str, + motion_module_kwargs: dict +): + if motion_module_type == "Vanilla": + return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs,) + else: + raise ValueError + + +class VanillaTemporalModule(nn.Module): + def __init__( + self, + in_channels, + num_attention_heads = 8, + num_transformer_block = 2, + attention_block_types =( "Temporal_Self", "Temporal_Self" ), + cross_frame_attention_mode = None, + temporal_position_encoding = True, + temporal_position_encoding_max_len = 24, + temporal_attention_dim_div = 1, + zero_initialize = True, + ): + super().__init__() + + self.temporal_transformer = TemporalTransformer3DModel( + in_channels=in_channels, + num_attention_heads=num_attention_heads, + attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div, + num_layers=num_transformer_block, + attention_block_types=attention_block_types, + cross_frame_attention_mode=cross_frame_attention_mode, + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, + ) + + if zero_initialize: + self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out) + + def forward(self, input_tensor, temb, encoder_hidden_states=None, attention_mask=None, anchor_frame_idx=None, video_start_index=0): + hidden_states = input_tensor + hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask, video_start_index=video_start_index) + + output = hidden_states + return output + + +class TemporalTransformer3DModel(nn.Module): + def __init__( + self, + in_channels, + num_attention_heads, + attention_head_dim, + + num_layers, + attention_block_types = ( "Temporal_Self", "Temporal_Self", ), + dropout = 0.0, + norm_num_groups = 32, + cross_attention_dim = 768, + activation_fn = "geglu", + attention_bias = False, + upcast_attention = False, + + cross_frame_attention_mode = None, + temporal_position_encoding = False, + temporal_position_encoding_max_len = 24, + ): + super().__init__() + + inner_dim = num_attention_heads * attention_head_dim + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + self.proj_in = nn.Linear(in_channels, inner_dim) + + self.transformer_blocks = nn.ModuleList( + [ + TemporalTransformerBlock( + dim=inner_dim, + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + attention_block_types=attention_block_types, + dropout=dropout, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + attention_bias=attention_bias, + upcast_attention=upcast_attention, + cross_frame_attention_mode=cross_frame_attention_mode, + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, + ) + for d in range(num_layers) + ] + ) + self.proj_out = nn.Linear(inner_dim, in_channels) + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_start_index=0): + assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + + batch, channel, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + inner_dim = hidden_states.shape[1] # c + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = self.proj_in(hidden_states) # 这不还是在像素对像素的通道维做attn吗 + + # Transformer Blocks + for block in self.transformer_blocks: + hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length, video_start_index=video_start_index) + + # output + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) + + return output + + +class TemporalTransformerBlock(nn.Module): + def __init__( + self, + dim, + num_attention_heads, + attention_head_dim, + attention_block_types = ( "Temporal_Self", "Temporal_Self", ), + dropout = 0.0, + norm_num_groups = 32, + cross_attention_dim = 768, + activation_fn = "geglu", + attention_bias = False, + upcast_attention = False, + cross_frame_attention_mode = None, + temporal_position_encoding = False, + temporal_position_encoding_max_len = 24, + ): + super().__init__() + + attention_blocks = [] + norms = [] + + for block_name in attention_block_types: + attention_blocks.append( + VersatileAttention( + attention_mode=block_name.split("_")[0], + cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None, + + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + + cross_frame_attention_mode=cross_frame_attention_mode, + temporal_position_encoding=temporal_position_encoding, + temporal_position_encoding_max_len=temporal_position_encoding_max_len, + ) + ) + norms.append(nn.LayerNorm(dim)) + + self.attention_blocks = nn.ModuleList(attention_blocks) + self.norms = nn.ModuleList(norms) + + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) + self.ff_norm = nn.LayerNorm(dim) + + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, video_start_index=0): + for attention_block, norm in zip(self.attention_blocks, self.norms): + norm_hidden_states = norm(hidden_states) + hidden_states = attention_block( + norm_hidden_states, + encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None, + video_length=video_length, + video_start_index=video_start_index, + ) + hidden_states + + hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states + + output = hidden_states + return output + + +class PositionalEncoding(nn.Module): + def __init__( + self, + d_model, + dropout = 0., + max_len = 24 + ): + super().__init__() + self.dropout = nn.Dropout(p=dropout) + position = torch.arange(max_len).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)) + pe = torch.zeros(1, max_len, d_model) + pe[0, :, 0::2] = torch.sin(position * div_term) + pe[0, :, 1::2] = torch.cos(position * div_term) + self.register_buffer('pe', pe) + + def forward(self, x, start_index): + if start_index + x.size(1) > self.pe.size(1): + start_index = start_index - self.pe.size(1) + if start_index < 0: + raise ValueError(f"start_index must be non-negative, but got {start_index}") + x = x + self.pe[:, start_index: start_index+x.size(1)] + return self.dropout(x) + + +class VersatileAttention(Attention): + use_memory_efficient_attention_xformers: bool = True + def __init__( + self, + attention_mode = None, + cross_frame_attention_mode = None, + temporal_position_encoding = False, + temporal_position_encoding_max_len = 24, + *args, **kwargs + ): + super().__init__(*args, **kwargs) + assert attention_mode == "Temporal" + + self.attention_mode = attention_mode + self.is_cross_attention = kwargs["cross_attention_dim"] is not None + + self.pos_encoder = PositionalEncoding( + kwargs["query_dim"], + dropout=0., + max_len=temporal_position_encoding_max_len + ) if (temporal_position_encoding and attention_mode == "Temporal") else None + + def extra_repr(self): + return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}" + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, **kwargs): + batch_size, sequence_length, _ = hidden_states.shape + + if self.attention_mode == "Temporal": + d = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) # 这边做了一个reshape转换, + # 终于是对帧与帧(bd, f, c)做attn了 + + if self.pos_encoder is not None: + hidden_states = self.pos_encoder(hidden_states, kwargs['video_start_index']) + + encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states + else: + raise NotImplementedError + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) + dim = query.shape[-1] + query = self.head_to_batch_dim(query) + + if self.added_kv_proj_dim is not None: + raise NotImplementedError + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + key = self.head_to_batch_dim(key) + value = self.head_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if hasattr(F, 'scaled_dot_product_attention'): + hidden_states = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + ) + elif self.use_memory_efficient_attention_xformers: + import xformers.ops as xops + hidden_states = xops.memory_efficient_attention(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + hidden_states = self.qkv_attention(query, key, value, attention_mask) + + hidden_states = self.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + + if self.attention_mode == "Temporal": + hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) + + return hidden_states + + def qkv_attention(self, q, k, v, mask=None): + sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale + + if mask is not None: + mask = rearrange(mask, 'b ... -> b (...)') + max_neg_value = -torch.finfo(sim.dtype).max + mask = repeat(mask, 'b j -> (b h) () j', h=self.heads) + sim.masked_fill_(~mask.to(torch.bool), max_neg_value) + + # attention, what we cannot get enough of + attn = sim.softmax(dim=-1) + + out = torch.einsum('b i j, b j d -> b i d', attn, v) + return out \ No newline at end of file diff --git a/modules/video_unet_temporal/resnet.py b/modules/video_unet_temporal/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..72c77246c594bbb13dcb54f8e51cf6ba43044938 --- /dev/null +++ b/modules/video_unet_temporal/resnet.py @@ -0,0 +1,209 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange + + +class InflatedConv3d(nn.Conv2d): # 这边是卷积的Inflate操作 + def forward(self, x): + video_length = x.shape[2] + # 关于这里: f在第三维, 在整个输入unet之前就完全变成这个维度模式了 + x = rearrange(x, "b c f h w -> (b f) c h w") # 这里还是蛮奇怪的,并不是所谓的2+1d, 而是直接rearrange输入 + x = super().forward(x) + x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) + + return x + + +class Upsample3D(nn.Module): + def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + conv = None + if use_conv_transpose: + raise NotImplementedError + elif use_conv: + conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) + + if name == "conv": + self.conv = conv + else: + self.Conv2d_0 = conv + + def forward(self, hidden_states, output_size=None): + assert hidden_states.shape[1] == self.channels + + if self.use_conv_transpose: + raise NotImplementedError + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + # if `output_size` is passed we force the interpolation output + # size and do not make use of `scale_factor=2` + if output_size is None: + hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + if self.use_conv: + if self.name == "conv": + hidden_states = self.conv(hidden_states) + else: + hidden_states = self.Conv2d_0(hidden_states) + + return hidden_states + + +class Downsample3D(nn.Module): + def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) + else: + raise NotImplementedError + + if name == "conv": + self.Conv2d_0 = conv + self.conv = conv + elif name == "Conv2d_0": + self.conv = conv + else: + self.conv = conv + + def forward(self, hidden_states): + assert hidden_states.shape[1] == self.channels + if self.use_conv and self.padding == 0: + raise NotImplementedError + + assert hidden_states.shape[1] == self.channels + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class ResnetBlock3D(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + non_linearity="swish", + time_embedding_norm="default", + output_scale_factor=1.0, + use_in_shortcut=None, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + + self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if temb_channels is not None: + if self.time_embedding_norm == "default": + time_emb_proj_out_channels = out_channels + elif self.time_embedding_norm == "scale_shift": + time_emb_proj_out_channels = out_channels * 2 + else: + raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") + + self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) + else: + self.time_emb_proj = None + + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + + self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, input_tensor, temb): + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] + + if temb is not None and self.time_embedding_norm == "default": + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + + if temb is not None and self.time_embedding_norm == "scale_shift": + scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + + +class Mish(torch.nn.Module): + def forward(self, hidden_states): + return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) \ No newline at end of file diff --git a/modules/video_unet_temporal/unet.py b/modules/video_unet_temporal/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..96247b8e3ff93ecacd62e8053e0c0c86979dc3f5 --- /dev/null +++ b/modules/video_unet_temporal/unet.py @@ -0,0 +1,477 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py + +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import os +import json + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import BaseOutput, logging +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from .unet_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, +) +from .resnet import InflatedConv3d + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet3DConditionOutput(BaseOutput): + sample: torch.FloatTensor + + +class UNet3DConditionModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + mid_block_type: str = "UNetMidBlock3DCrossAttn", + up_block_types: Tuple[str] = ( + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D" + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: Union[int, Tuple[int]] = 8, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + # Additional + use_motion_module = True, + motion_module_resolutions = ( 1,2,4,8 ), + motion_module_mid_block = True, + motion_module_decoder_only = False, + motion_module_type = 'Vanilla', + motion_module_kwargs = {}, + time_cond_proj_dim = None + ): + super().__init__() + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + self.use_motion_module = use_motion_module + # input + self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + + # time + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, cond_proj_dim=time_cond_proj_dim) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + else: + self.class_embedding = None + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + res = 2 ** i + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_motion_module=use_motion_module and (res in motion_module_resolutions) and (not motion_module_decoder_only), + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock3DCrossAttn": + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + use_motion_module=use_motion_module and motion_module_mid_block, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the videos + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + only_cross_attention = list(reversed(only_cross_attention)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + res = 2 ** (3 - i) + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_motion_module=use_motion_module and (res in motion_module_resolutions), + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + + num_slicable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] + + slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + video_start_index: int = 0, + timestep_cond: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None + ) -> Union[UNet3DConditionOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # # center input if necessary + # if self.config.center_input_sample: + # sample = 2 * sample - 1.0 + + # time + timesteps = timestep + if not torch.is_tensor(timesteps): + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + # timesteps = timesteps.expand(sample.shape[0]) #! change + timesteps = timesteps.repeat(sample.shape[0] // timesteps.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb, timestep_cond) + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + # pre-process + sample = self.conv_in(sample) + + # down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + video_start_index=video_start_index, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, video_start_index=video_start_index) + + down_block_res_samples += res_samples + + # mid + sample = self.mid_block( + sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, video_start_index=video_start_index + ) + + # up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + video_start_index=video_start_index, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size, video_start_index=video_start_index + ) + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNet3DConditionOutput(sample=sample) + + # @classmethod + # def from_pretrained_2d(cls, pretrained_model_path, subfolder=None): + # if subfolder is not None: + # pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + + # config_file = os.path.join(pretrained_model_path, 'config.json') + # if not os.path.isfile(config_file): + # raise RuntimeError(f"{config_file} does not exist") + # with open(config_file, "r") as f: + # config = json.load(f) + # config["_class_name"] = cls.__name__ + # config["down_block_types"] = [ + # "CrossAttnDownBlock3D", + # "CrossAttnDownBlock3D", + # "CrossAttnDownBlock3D", + # "DownBlock3D" + # ] + # config["up_block_types"] = [ + # "UpBlock3D", + # "CrossAttnUpBlock3D", + # "CrossAttnUpBlock3D", + # "CrossAttnUpBlock3D" + # ] + + # from diffusers.utils import WEIGHTS_NAME + # model = cls.from_config(config) + # model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) + # if not os.path.isfile(model_file): + # raise RuntimeError(f"{model_file} does not exist") + # state_dict = torch.load(model_file, map_location="cpu") + # for k, v in model.state_dict().items(): + # if '_temp.' in k: + # state_dict.update({k: v}) + # model.load_state_dict(state_dict) + + # return model \ No newline at end of file diff --git a/modules/video_unet_temporal/unet_blocks.py b/modules/video_unet_temporal/unet_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..c4bb13a2508e5b33160a4adf43b122f4749fe774 --- /dev/null +++ b/modules/video_unet_temporal/unet_blocks.py @@ -0,0 +1,680 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py + +import torch +from torch import nn + +from .attention import Transformer3DModel +from .resnet import Downsample3D, ResnetBlock3D, Upsample3D +from .motion_module import get_motion_module + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, +): + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock3D": + return DownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + use_motion_module=use_motion_module, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + elif down_block_type == "CrossAttnDownBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") + return CrossAttnDownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_motion_module=use_motion_module, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, +): + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock3D": + return UpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + use_motion_module=use_motion_module, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + elif up_block_type == "CrossAttnUpBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") + return CrossAttnUpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_motion_module=use_motion_module, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock3DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + motion_modules = [] + + for _ in range(num_layers): + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=in_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) if use_motion_module else None + ) + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, video_start_index=0): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules): + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states, video_start_index=video_start_index) if motion_module is not None else hidden_states + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class CrossAttnDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) if use_motion_module else None + ) + self.attentions = nn.ModuleList(attentions) # spatial attention + self.resnets = nn.ModuleList(resnets) # downsample + self.motion_modules = nn.ModuleList(motion_modules) # motion layer + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, video_start_index=0): + output_states = () + + for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + )[0] + if motion_module is not None: + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states, attention_mask, None, video_start_index) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states, video_start_index=video_start_index) if motion_module is not None else hidden_states + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + motion_modules = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) if use_motion_module else None + ) + + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None, video_start_index=0): + output_states = () + + for resnet, motion_module in zip(self.resnets, self.motion_modules): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + if motion_module is not None: + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, None, None, None, video_start_index) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = motion_module(hidden_states, temb, video_start_index=video_start_index) if motion_module is not None else hidden_states + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + attentions = [] + motion_modules = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) if use_motion_module else None + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + attention_mask=None, + video_start_index=0, + ): + for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + )[0] + if motion_module is not None: + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, encoder_hidden_states, attention_mask, None, video_start_index) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states, video_start_index=video_start_index) if motion_module is not None else hidden_states + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + use_motion_module=None, + motion_module_type=None, + motion_module_kwargs=None, + ): + super().__init__() + resnets = [] + motion_modules = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + motion_modules.append( + get_motion_module( + in_channels=out_channels, + motion_module_type=motion_module_type, + motion_module_kwargs=motion_module_kwargs, + ) if use_motion_module else None + ) + + self.resnets = nn.ModuleList(resnets) + self.motion_modules = nn.ModuleList(motion_modules) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, video_start_index=0): + for resnet, motion_module in zip(self.resnets, self.motion_modules): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + if motion_module is not None: + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module), hidden_states.requires_grad_(), temb, None, None, None, video_start_index) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = motion_module(hidden_states, temb, video_start_index=video_start_index) if motion_module is not None else hidden_states + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states diff --git a/modules/vqvae/__pycache__/model.cpython-310.pyc b/modules/vqvae/__pycache__/model.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bda02b3bf39e14b542a3b0fddf1a308922810a17 Binary files /dev/null and b/modules/vqvae/__pycache__/model.cpython-310.pyc differ diff --git a/modules/vqvae/__pycache__/model.cpython-38.pyc b/modules/vqvae/__pycache__/model.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13e90d4d0f8cf7d2ffb209c1059077819f50e385 Binary files /dev/null and b/modules/vqvae/__pycache__/model.cpython-38.pyc differ diff --git a/modules/vqvae/autoencoder.py b/modules/vqvae/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..02e6279e76e97343f3f98bc0f31960281d3cc64e --- /dev/null +++ b/modules/vqvae/autoencoder.py @@ -0,0 +1,283 @@ +import torch +import numpy as np +import pytorch_lightning as pl +import torch.nn.functional as F +from contextlib import contextmanager + +from taming.modules.vqvae.quantize import VectorQuantizer2 as VectorQuantizer + +from .model import Encoder, Decoder + +from misc_utils.model_utils import instantiate_from_config + + +class VQModel(pl.LightningModule): + def __init__(self, + ddconfig, + lossconfig, + n_embed, + embed_dim, + ckpt_path=None, + ignore_keys=[], + image_key="image", + colorize_nlabels=None, + monitor=None, + batch_resize_range=None, + scheduler_config=None, + lr_g_factor=1.0, + remap=None, + sane_index_shape=False, # tell vector quantizer to return indices as bhw + use_ema=False + ): + super().__init__() + self.embed_dim = embed_dim + self.n_embed = n_embed + self.image_key = image_key + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + self.loss = instantiate_from_config(lossconfig) + self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25, + remap=remap, + sane_index_shape=sane_index_shape) + self.quant_conv = torch.nn.Conv2d(ddconfig["z_channels"], embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + if colorize_nlabels is not None: + assert type(colorize_nlabels)==int + self.register_buffer("colorize", torch.randn(3, colorize_nlabels, 1, 1)) + if monitor is not None: + self.monitor = monitor + self.batch_resize_range = batch_resize_range + if self.batch_resize_range is not None: + print(f"{self.__class__.__name__}: Using per-batch resizing in range {batch_resize_range}.") + + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self) + print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + if ckpt_path is not None: + self.init_from_ckpt(ckpt_path, ignore_keys=ignore_keys) + self.scheduler_config = scheduler_config + self.lr_g_factor = lr_g_factor + + @contextmanager + def ema_scope(self, context=None): + if self.use_ema: + self.model_ema.store(self.parameters()) + self.model_ema.copy_to(self) + if context is not None: + print(f"{context}: Switched to EMA weights") + try: + yield None + finally: + if self.use_ema: + self.model_ema.restore(self.parameters()) + if context is not None: + print(f"{context}: Restored training weights") + + def init_from_ckpt(self, path, ignore_keys=list()): + sd = torch.load(path, map_location="cpu")["state_dict"] + keys = list(sd.keys()) + for k in keys: + for ik in ignore_keys: + if k.startswith(ik): + print("Deleting key {} from state_dict.".format(k)) + del sd[k] + missing, unexpected = self.load_state_dict(sd, strict=False) + print(f"Restored from {path} with {len(missing)} missing and {len(unexpected)} unexpected keys") + if len(missing) > 0: + print(f"Missing Keys: {missing}") + print(f"Unexpected Keys: {unexpected}") + + def on_train_batch_end(self, *args, **kwargs): + if self.use_ema: + self.model_ema(self) + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + quant, emb_loss, info = self.quantize(h) + return quant, emb_loss, info + + def encode_to_prequant(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, quant): + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + + def decode_code(self, code_b): + quant_b = self.quantize.embed_code(code_b) + dec = self.decode(quant_b) + return dec + + def forward(self, input, return_pred_indices=False): + quant, diff, (_,_,ind) = self.encode(input) + dec = self.decode(quant) + if return_pred_indices: + return dec, diff, ind + return dec, diff + + def get_input(self, batch, k): + x = batch[k] + if len(x.shape) == 3: + x = x[..., None] + x = x.permute(0, 3, 1, 2).to(memory_format=torch.contiguous_format).float() + if self.batch_resize_range is not None: + lower_size = self.batch_resize_range[0] + upper_size = self.batch_resize_range[1] + if self.global_step <= 4: + # do the first few batches with max size to avoid later oom + new_resize = upper_size + else: + new_resize = np.random.choice(np.arange(lower_size, upper_size+16, 16)) + if new_resize != x.shape[2]: + x = F.interpolate(x, size=new_resize, mode="bicubic") + x = x.detach() + return x + + def training_step(self, batch, batch_idx, optimizer_idx): + # https://github.com/pytorch/pytorch/issues/37142 + # try not to fool the heuristics + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + + if optimizer_idx == 0: + # autoencode + aeloss, log_dict_ae = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train", + predicted_indices=ind) + + self.log_dict(log_dict_ae, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return aeloss + + if optimizer_idx == 1: + # discriminator + discloss, log_dict_disc = self.loss(qloss, x, xrec, optimizer_idx, self.global_step, + last_layer=self.get_last_layer(), split="train") + self.log_dict(log_dict_disc, prog_bar=False, logger=True, on_step=True, on_epoch=True) + return discloss + + def validation_step(self, batch, batch_idx): + log_dict = self._validation_step(batch, batch_idx) + with self.ema_scope(): + log_dict_ema = self._validation_step(batch, batch_idx, suffix="_ema") + return log_dict + + def _validation_step(self, batch, batch_idx, suffix=""): + x = self.get_input(batch, self.image_key) + xrec, qloss, ind = self(x, return_pred_indices=True) + aeloss, log_dict_ae = self.loss(qloss, x, xrec, 0, + self.global_step, + last_layer=self.get_last_layer(), + split="val"+suffix, + predicted_indices=ind + ) + + discloss, log_dict_disc = self.loss(qloss, x, xrec, 1, + self.global_step, + last_layer=self.get_last_layer(), + split="val"+suffix, + predicted_indices=ind + ) + rec_loss = log_dict_ae[f"val{suffix}/rec_loss"] + self.log(f"val{suffix}/rec_loss", rec_loss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + self.log(f"val{suffix}/aeloss", aeloss, + prog_bar=True, logger=True, on_step=False, on_epoch=True, sync_dist=True) + if version.parse(pl.__version__) >= version.parse('1.4.0'): + del log_dict_ae[f"val{suffix}/rec_loss"] + self.log_dict(log_dict_ae) + self.log_dict(log_dict_disc) + return self.log_dict + + def configure_optimizers(self): + lr_d = self.learning_rate + lr_g = self.lr_g_factor*self.learning_rate + print("lr_d", lr_d) + print("lr_g", lr_g) + opt_ae = torch.optim.Adam(list(self.encoder.parameters())+ + list(self.decoder.parameters())+ + list(self.quantize.parameters())+ + list(self.quant_conv.parameters())+ + list(self.post_quant_conv.parameters()), + lr=lr_g, betas=(0.5, 0.9)) + opt_disc = torch.optim.Adam(self.loss.discriminator.parameters(), + lr=lr_d, betas=(0.5, 0.9)) + + if self.scheduler_config is not None: + scheduler = instantiate_from_config(self.scheduler_config) + + print("Setting up LambdaLR scheduler...") + scheduler = [ + { + 'scheduler': LambdaLR(opt_ae, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + { + 'scheduler': LambdaLR(opt_disc, lr_lambda=scheduler.schedule), + 'interval': 'step', + 'frequency': 1 + }, + ] + return [opt_ae, opt_disc], scheduler + return [opt_ae, opt_disc], [] + + def get_last_layer(self): + return self.decoder.conv_out.weight + + def log_images(self, batch, only_inputs=False, plot_ema=False, **kwargs): + log = dict() + x = self.get_input(batch, self.image_key) + x = x.to(self.device) + if only_inputs: + log["inputs"] = x + return log + xrec, _ = self(x) + if x.shape[1] > 3: + # colorize with random projection + assert xrec.shape[1] > 3 + x = self.to_rgb(x) + xrec = self.to_rgb(xrec) + log["inputs"] = x + log["reconstructions"] = xrec + if plot_ema: + with self.ema_scope(): + xrec_ema, _ = self(x) + if x.shape[1] > 3: xrec_ema = self.to_rgb(xrec_ema) + log["reconstructions_ema"] = xrec_ema + return log + + def to_rgb(self, x): + assert self.image_key == "segmentation" + if not hasattr(self, "colorize"): + self.register_buffer("colorize", torch.randn(3, x.shape[1], 1, 1).to(x)) + x = F.conv2d(x, weight=self.colorize) + x = 2.*(x-x.min())/(x.max()-x.min()) - 1. + return x + + +class VQModelInterface(VQModel): + def __init__(self, embed_dim, *args, **kwargs): + super().__init__(embed_dim=embed_dim, *args, **kwargs) + self.embed_dim = embed_dim + + def encode(self, x): + h = self.encoder(x) + h = self.quant_conv(h) + return h + + def decode(self, h, force_not_quantize=False): + # also go through quantization layer + if not force_not_quantize: + quant, emb_loss, info = self.quantize(h) + else: + quant = h + quant = self.post_quant_conv(quant) + dec = self.decoder(quant) + return dec + diff --git a/modules/vqvae/model.py b/modules/vqvae/model.py new file mode 100644 index 0000000000000000000000000000000000000000..8812e6536dd121993dfb20d9354bf9a2ae31309c --- /dev/null +++ b/modules/vqvae/model.py @@ -0,0 +1,412 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias = False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) + k = k.softmax(dim=-1) + context = torch.einsum('bhdn,bhen->bhde', k, v) + out = torch.einsum('bhde,bhdn->bhen', context, q) + out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) + return self.to_out(out) + +def nonlinearity(x): + # swish + return x*torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + dtype = x.dtype # handle bf16 problem + x = torch.nn.functional.interpolate(x.float(), scale_factor=2.0, mode="nearest") + x = x.to(dtype) + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=3, + stride=2, + padding=0) + + def forward(self, x): + if self.with_conv: + pad = (0,1,0,1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2) + return x + + +class ResnetBlock(nn.Module): + def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, + dropout, temb_channels=512): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + self.norm1 = Normalize(in_channels) + self.conv1 = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if temb_channels > 0: + self.temb_proj = torch.nn.Linear(temb_channels, + out_channels) + self.norm2 = Normalize(out_channels) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = torch.nn.Conv2d(out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + self.conv_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1) + else: + self.nin_shortcut = torch.nn.Conv2d(in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0) + + def forward(self, x, temb): + h = x + h = self.norm1(h) + h = nonlinearity(h) + h = self.conv1(h) + + if temb is not None: + h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + + h = self.norm2(h) + h = nonlinearity(h) + h = self.dropout(h) + h = self.conv2(h) + + if self.in_channels != self.out_channels: + if self.use_conv_shortcut: + x = self.conv_shortcut(x) + else: + x = self.nin_shortcut(x) + + return x+h + + +class LinAttnBlock(LinearAttention): + """to match AttnBlock usage""" + def __init__(self, in_channels): + super().__init__(dim=in_channels, heads=1, dim_head=in_channels) + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.k = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.v = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + self.proj_out = torch.nn.Conv2d(in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0) + + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b,c,h,w = q.shape + q = q.reshape(b,c,h*w) + q = q.permute(0,2,1) # b,hw,c + k = k.reshape(b,c,h*w) # b,c,hw + w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w_ = w_ * (int(c)**(-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b,c,h*w) + w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) + h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + h_ = h_.reshape(b,c,h,w) + + h_ = self.proj_out(h_) + + return x+h_ + + +def make_attn(in_channels, attn_type="vanilla"): + assert attn_type in ["vanilla", "linear", "none"], f'attn_type {attn_type} unknown' + print(f"making attention of type '{attn_type}' with {in_channels} in_channels") + if attn_type == "vanilla": + return AttnBlock(in_channels) + elif attn_type == "none": + return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Encoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, double_z=True, use_linear_attn=False, attn_type="vanilla", + **ignore_kwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + # downsampling + self.conv_in = torch.nn.Conv2d(in_channels, + self.ch, + kernel_size=3, + stride=1, + padding=1) + + curr_res = resolution + in_ch_mult = (1,)+tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch*in_ch_mult[i_level] + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions-1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + 2*z_channels if double_z else z_channels, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, x): + # timestep embedding + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions-1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # end + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + return h + + +class Decoder(nn.Module): + def __init__(self, *, ch, out_ch, ch_mult=(1,2,4,8), num_res_blocks, + attn_resolutions, dropout=0.0, resamp_with_conv=True, in_channels, + resolution, z_channels, give_pre_end=False, tanh_out=False, use_linear_attn=False, + attn_type="vanilla", **ignorekwargs): + super().__init__() + if use_linear_attn: attn_type = "linear" + self.ch = ch + self.temb_ch = 0 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.give_pre_end = give_pre_end + self.tanh_out = tanh_out + + # compute in_ch_mult, block_in and curr_res at lowest res + in_ch_mult = (1,)+tuple(ch_mult) + block_in = ch*ch_mult[self.num_resolutions-1] + curr_res = resolution // 2**(self.num_resolutions-1) + self.z_shape = (1,z_channels,curr_res,curr_res) + print("Working with z of shape {} = {} dimensions.".format( + self.z_shape, np.prod(self.z_shape))) + + # z to block_in + self.conv_in = torch.nn.Conv2d(z_channels, + block_in, + kernel_size=3, + stride=1, + padding=1) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = ResnetBlock(in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch*ch_mult[i_level] + for i_block in range(self.num_res_blocks+1): + block.append(ResnetBlock(in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout)) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d(block_in, + out_ch, + kernel_size=3, + stride=1, + padding=1) + + def forward(self, z): + #assert z.shape[1:] == self.z_shape[1:] + self.last_z_shape = z.shape + + # timestep embedding + temb = None + + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = self.mid.block_2(h, temb) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks+1): + h = self.up[i_level].block[i_block](h, temb) + if len(self.up[i_level].attn) > 0: + h = self.up[i_level].attn[i_block](h) + if i_level != 0: + h = self.up[i_level].upsample(h) + + # end + if self.give_pre_end: + return h + + h = self.norm_out(h) + h = nonlinearity(h) + h = self.conv_out(h) + if self.tanh_out: + h = torch.tanh(h) + return h + diff --git a/pl_trainer/__pycache__/diffusion.cpython-310.pyc b/pl_trainer/__pycache__/diffusion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88b6c2f694f5f1a014dc4ca342562c7e897352b2 Binary files /dev/null and b/pl_trainer/__pycache__/diffusion.cpython-310.pyc differ diff --git a/pl_trainer/__pycache__/diffusion.cpython-38.pyc b/pl_trainer/__pycache__/diffusion.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..59d9621293913fd621c3e6a65fa1feef55a362d3 Binary files /dev/null and b/pl_trainer/__pycache__/diffusion.cpython-38.pyc differ diff --git a/pl_trainer/__pycache__/instruct_p2p_video.cpython-310.pyc b/pl_trainer/__pycache__/instruct_p2p_video.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32167321715de64cd7b2101dd265826b07cf8bdb Binary files /dev/null and b/pl_trainer/__pycache__/instruct_p2p_video.cpython-310.pyc differ diff --git a/pl_trainer/__pycache__/instruct_p2p_video.cpython-38.pyc b/pl_trainer/__pycache__/instruct_p2p_video.cpython-38.pyc new file mode 100644 index 0000000000000000000000000000000000000000..69c099f8e3236062fa2f0f165419fd7ae0c0c2be Binary files /dev/null and b/pl_trainer/__pycache__/instruct_p2p_video.cpython-38.pyc differ diff --git a/pl_trainer/diffusion.py b/pl_trainer/diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..8120721edb734eb1dca217f258da57665cb3d64a --- /dev/null +++ b/pl_trainer/diffusion.py @@ -0,0 +1,365 @@ +import torch +from torch import nn +import pytorch_lightning as pl +from misc_utils.model_utils import default, instantiate_from_config +from diffusers import DDPMScheduler + +from safetensors.torch import load_file + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + +class DDPM(pl.LightningModule): + def __init__( + self, + unet, + beta_schedule_args={ + 'beta_start': 0.00085, + 'beta_end': 0.0012, + 'num_train_timesteps': 1000, + 'beta_schedule': 'scaled_linear', + 'clip_sample': False, + 'thresholding': False, + }, + prediction_type='epsilon', + loss_fn='l2', + optim_args={}, + base_path=None, + **kwargs + ): + ''' + denoising_fn: a denoising model such as UNet + beta_schedule_args: a dictionary which contains + the configurations of the beta schedule + ''' + super().__init__(**kwargs) + self.unet = unet + self.prediction_type = prediction_type + beta_schedule_args.update({'prediction_type': prediction_type}) + self.set_beta_schedule(beta_schedule_args) + self.num_timesteps = beta_schedule_args['num_train_timesteps'] + self.optim_args = optim_args + self.loss = loss_fn + self.base_path = base_path + if loss_fn == 'l2' or loss_fn == 'mse': + self.loss_fn = nn.MSELoss(reduction='none') + elif loss_fn == 'l1' or loss_fn == 'mae': + self.loss_fn = nn.L1Loss(reduction='none') + elif isinstance(loss_fn, dict): + self.loss_fn = instantiate_from_config(loss_fn) + else: + raise NotImplementedError + + def set_beta_schedule(self, beta_schedule_args): + self.beta_schedule_args = beta_schedule_args + self.scheduler = DDPMScheduler(**beta_schedule_args) + + @torch.no_grad() + def add_noise(self, x, t, noise=None): + noise = default(noise, torch.randn_like(x)) + return self.scheduler.add_noise(x, noise, t) + + def predict_x_0_from_x_t(self, model_output: torch.Tensor, t: torch.LongTensor, x_t: torch.Tensor): # 这边是一个缓存值: predicted x0 + ''' recover x_0 from predicted noise. Reverse of Eq(4) in DDPM paper + \hat(x_0) = 1 / sqrt[\bar(a)]*x_t - sqrt[(1-\bar(a)) / \bar(a)]*noise''' + # return self.scheduler.step(model_output, int(t), x_t).pred_original_sample + if self.prediction_type == 'sample': + return model_output + # for training target == epsilon + alphas_cumprod = self.scheduler.alphas_cumprod.to(device=x_t.device, dtype=x_t.dtype) + sqrt_recip_alphas_cumprod = torch.sqrt(1. / alphas_cumprod[t]).flatten() + sqrt_recipm1_alphas_cumprod = torch.sqrt(1. / alphas_cumprod[t] - 1.).flatten() + while len(sqrt_recip_alphas_cumprod.shape) < len(x_t.shape): + sqrt_recip_alphas_cumprod = sqrt_recip_alphas_cumprod.unsqueeze(-1) + sqrt_recipm1_alphas_cumprod = sqrt_recipm1_alphas_cumprod.unsqueeze(-1) + return sqrt_recip_alphas_cumprod * x_t - sqrt_recipm1_alphas_cumprod * model_output + + def predict_x_tm1_from_x_t(self, model_output, t, x_t): + '''predict x_{t-1} from x_t and model_output''' + return self.scheduler.step(model_output, t, x_t).prev_sample + +class DDPMTraining(DDPM): # 加入training step保证训练等等 + def __init__( + self, + unet, + beta_schedule_args, + prediction_type='epsilon', + loss_fn='l2', + optim_args={ + 'lr': 1e-3, + 'weight_decay': 5e-4 + }, + log_args={}, # for record all arguments with self.save_hyperparameters + ddim_sampling_steps=20, + guidance_scale=5., + **kwargs + ): + super().__init__( + unet=unet, + beta_schedule_args=beta_schedule_args, + prediction_type=prediction_type, + loss_fn=loss_fn, + optim_args=optim_args, + **kwargs) + self.log_args = log_args + self.call_save_hyperparameters() + + self.ddim_sampling_steps = ddim_sampling_steps + self.guidance_scale = guidance_scale + + def call_save_hyperparameters(self): + '''write in a separate function so that the inherit class can overwrite it''' + self.save_hyperparameters(ignore=['unet']) + + def process_batch(self, x_0, mode): + assert mode in ['train', 'val', 'test'] + b, *_ = x_0.shape + noise = torch.randn_like(x_0) + if mode == 'train': + t = torch.randint(0, self.num_timesteps, (b,), device=x_0.device).long() + x_t = self.add_noise(x_0, t, noise=noise) + else: + t = torch.full((b,), self.num_timesteps-1, device=x_0.device, dtype=torch.long) + x_t = self.add_noise(x_0, t, noise=noise) + + model_kwargs = {} + '''the order of return is + 1) model input, + 2) model pred target, + 3) model time condition + 4) raw image before adding noise + 5) model_kwargs + ''' + if self.prediction_type == 'epsilon': + return { + 'model_input': x_t, + 'model_target': noise, + 't': t, + 'model_kwargs': model_kwargs + } + else: + return { + 'model_input': x_t, + 'model_target': x_0, + 't': t, + 'model_kwargs': model_kwargs + } + + def forward(self, x): + return self.validation_step(x, 0) + + def get_loss(self, pred, target, t): + loss_raw = self.loss_fn(pred, target) + loss_flat = mean_flat(loss_raw) + + loss = loss_flat + loss = loss.mean() + + return loss + + def get_hdr_loss(self, fg_mask, pred, pred_combine): # fg_mask: 1,16,4,64,64 都是这个维度 + # import pdb; pdb.set_trace() #todo 打印维度, 查看是否有问题 + loss_raw = self.loss_fn(pred, pred_combine) #(1,16,4,64,64) + masked_loss = fg_mask * loss_raw + loss_flat = mean_flat(masked_loss) + + loss = loss_flat + loss = loss.mean() + + return loss + + def training_step(self, batch, batch_idx): + self.clip_denoised = False + processed_batch = self.process_batch(batch, mode='train') + x_t = processed_batch['model_input'] + y = processed_batch['model_target'] + t = processed_batch['t'] + model_kwargs = processed_batch['model_kwargs'] + pred = self.unet(x_t, t, **model_kwargs) + loss = self.get_loss(pred, y, t) + x_0_hat = self.predict_x_0_from_x_t(pred, t, x_t) + + self.log(f'train_loss', loss) + return { + 'loss': loss, + 'model_input': x_t, + 'model_output': pred, + 'x_0_hat': x_0_hat + } + + @torch.no_grad() + def validation_step(self, batch, batch_idx): + from diffusers import DDIMScheduler + scheduler = DDIMScheduler(**self.beta_schedule_args) + scheduler.set_timesteps(self.ddim_sampling_steps) + processed_batch = self.process_batch(batch, mode='val') + x_t = torch.randn_like(processed_batch['model_input']) + x_hist = [] + timesteps = scheduler.timesteps + for i, t in enumerate(timesteps): + t_ = torch.full((x_t.shape[0],), t, device=x_t.device, dtype=torch.long) + model_output = self.unet(x_t, t_, **processed_batch['model_kwargs']) + x_hist.append( + self.predict_x_0_from_x_t(model_output, t_, x_t) + ) + x_t = scheduler.step(model_output, t, x_t).prev_sample + + return { + 'x_pred': x_t, + 'x_hist': torch.stack(x_hist, dim=1), + } + + def test_step(self, batch, batch_idx): + '''Test is usually not used in a sampling problem''' + return self.validation_step(batch, batch_idx) + + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters(), **self.optim_args) + return optimizer + +class DDPMLDMTraining(DDPMTraining): # 加入潜变量, LDM 即在latent层面上来做 + def __init__( + self, *args, + vae, + unet_init_weights=None, + vae_init_weights=None, + scale_factor=0.18215, + **kwargs + ): + super().__init__(*args, **kwargs) + self.vae = vae + self.scale_factor = scale_factor + self.initialize_unet(unet_init_weights) + self.initialize_vqvae(vae_init_weights) # 这边可以把这个设为none(config文件里面) + + def initialize_unet(self, unet_init_weights): + if unet_init_weights is not None: + print(f'INFO: initialize denoising UNet from {unet_init_weights}') + sd = torch.load(unet_init_weights, map_location='cpu') + self.unet.load_state_dict(sd) + + def initialize_vqvae(self, vqvae_init_weights): # 这边vae load最后调用就是这个init函数 + if vqvae_init_weights is not None: + print(f'INFO: initialize VQVAE from {vqvae_init_weights}') + if '.safetensors' in vqvae_init_weights: + sd = load_file(vqvae_init_weights) + else: + sd = torch.load(vqvae_init_weights, map_location='cpu') + self.vae.load_state_dict(sd) + for param in self.vae.parameters(): + param.requires_grad = False # vae 也是冻住参数的 + + def call_save_hyperparameters(self): + '''write in a separate function so that the inherit class can overwrite it''' + self.save_hyperparameters(ignore=['unet', 'vae']) + + @torch.no_grad() + def encode_image_to_latent(self, x): + #return self.vae.encode(x) * self.scale_factor #! change + return self.vae.encode(x).latent_dist.mean * self.scale_factor + + @torch.no_grad() + def decode_latent_to_image(self, x): + x = x / self.scale_factor # 注意一下这个东西出现 必须要一致 sample乘以了, 这边就得除以 + return self.vae.decode(x) + + def process_batch(self, x_0, mode): + x_0 = self.encode_image_to_latent(x_0) + res = super().process_batch(x_0, mode) + return res + + def training_step(self, batch, batch_idx): + res_dict = super().training_step(batch, batch_idx) + res_dict['x_0_hat'] = self.decode_latent_to_image(res_dict['x_0_hat']) + return res_dict + +class DDIMLDMTextTraining(DDPMLDMTraining): # 加入text encoder以及文本编码进行条件生成;+改成DDIM 训练 + def __init__( + self, *args, + text_model, + text_model_init_weights=None, + **kwargs + ): + super().__init__( + *args, **kwargs + ) + self.text_model = text_model + self.initialize_text_model(text_model_init_weights) #! 这个也可以不要, 直接设置weights=None + + def initialize_text_model(self, text_model_init_weights): # 这边text model最后调用就是这个init函数 + if text_model_init_weights is not None: + print(f'INFO: initialize text model from {text_model_init_weights}') + sd = torch.load(text_model_init_weights, map_location='cpu') + self.text_model.load_state_dict(sd) + for param in self.text_model.parameters(): + param.requires_grad = False # 这边设置了text model不回传梯度 + + def call_save_hyperparameters(self): + '''write in a separate function so that the inherit class can overwrite it''' + self.save_hyperparameters(ignore=['unet', 'vae', 'text_model']) + + @torch.no_grad() + def encode_text(self, x): + if isinstance(x, tuple): + x = list(x) + return self.text_model.encode(x) + + def process_batch(self, batch, mode): + x_0 = batch['image'] + text = batch['text'] + processed_batch = super().process_batch(x_0, mode) + processed_batch['model_kwargs'].update({ + 'context': {'text': self.encode_text([text])} + }) + return processed_batch + + def sampling(self, image_shape=(1, 4, 64, 64), text='', negative_text=None): + ''' + Usage: + sampled = self.sampling(text='a cat on the tree', negative_text='') + + x = sampled['x_pred'][0].permute(1, 2, 0).detach().cpu().numpy() + x = x / 2 + 0.5 + plt.imshow(x) + + y = sampled['x_hist'][0, 10].permute(1, 2, 0).detach().cpu().numpy() + y = y / 2 + 0.5 + plt.imshow(y) + ''' + from diffusers import DDIMScheduler # ddim训练 + scheduler = DDIMScheduler(**self.beta_schedule_args) + scheduler.set_timesteps(self.ddim_sampling_steps) + x_t = torch.randn(*image_shape, device=self.device) + + do_cfg = self.guidance_scale > 1. and negative_text is not None + + if do_cfg: + context = {'text': self.encode_text([text, negative_text])} + else: + context = {'text': self.encode_text([text])} + x_hist = [] + timesteps = scheduler.timesteps + for i, t in enumerate(timesteps): + if do_cfg: + model_input = torch.cat([x_t]*2) + else: + model_input = x_t + t_ = torch.full((model_input.shape[0],), t, device=x_t.device, dtype=torch.long) + model_output = self.unet(model_input, t_, context) + + if do_cfg: + model_output_positive, model_output_negative = model_output.chunk(2) + model_output = model_output_negative + self.guidance_scale * (model_output_positive - model_output_negative) + x_hist.append( + self.decode_latent_to_image(self.predict_x_0_from_x_t(model_output, t_[:x_t.shape[0]], x_t)) + ) + x_t = scheduler.step(model_output, t, x_t).prev_sample + + return { + 'x_pred': self.decode_latent_to_image(x_t), + 'x_hist': torch.stack(x_hist, dim=1), + } diff --git a/pl_trainer/inference/__pycache__/inference.cpython-310.pyc b/pl_trainer/inference/__pycache__/inference.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c04bdd00789286f6329282a10fae9efc74c9a31a Binary files /dev/null and b/pl_trainer/inference/__pycache__/inference.cpython-310.pyc differ diff --git a/pl_trainer/inference/inference.py b/pl_trainer/inference/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..26721e7bd5f7b8244e3085c6ef5b918333b76fe5 --- /dev/null +++ b/pl_trainer/inference/inference.py @@ -0,0 +1,678 @@ +import torch +import numpy as np +from typing import Optional, Union, Tuple, List, Callable, Dict +from tqdm import tqdm +import torch +from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler +import torch.nn.functional as nnf +import numpy as np +from einops import rearrange +from misc_utils.flow_utils import warp_image, RAFTFlow, resize_flow +from functools import partial + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + +class Inference(): + def __init__( + self, + unet, + scheduler='ddim', + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", + num_ddim_steps=20, guidance_scale=5, + ): + self.unet = unet + if scheduler == 'ddim': + scheduler_cls = DDIMScheduler + scheduler_kwargs = {'set_alpha_to_one': False, 'steps_offset': 1, 'clip_sample': False} + elif scheduler == 'ddpm': + scheduler_cls = DDPMScheduler + scheduler_kwargs = {'clip_sample': False} + else: + raise NotImplementedError() + self.scheduler = scheduler_cls( + beta_start = beta_start, + beta_end = beta_end, + beta_schedule = beta_schedule, + **scheduler_kwargs + ) + self.scheduler.set_timesteps(num_ddim_steps) + self.num_ddim_steps = num_ddim_steps + self.guidance_scale = guidance_scale + + @torch.no_grad() + def __call__( + self, + latent: torch.Tensor, + context: torch.Tensor, + uncond_context: torch.Tensor=None, + start_time: int = 0, + null_embedding: List[torch.Tensor]=None, + context_kwargs={}, + model_kwargs={}, + ): + all_latent = [] + all_pred = [] # x0_hat + do_classifier_free_guidance = self.guidance_scale > 1 and ((uncond_context is not None) or (null_embedding is not None)) + for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])): + t = int(t) + if do_classifier_free_guidance: + latent_input = torch.cat([latent, latent], dim=0) + if null_embedding is not None: + context_input = torch.cat([null_embedding[i], context], dim=0) + else: + context_input = torch.cat([uncond_context, context], dim=0) + else: + latent_input = latent + context_input = context + noise_pred = self.unet( + latent_input, + torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long), + context={ 'text': context_input, **context_kwargs}, + **model_kwargs + ) + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + + pred_samples = self.scheduler.step(noise_pred, t, latent) + latent = pred_samples.prev_sample + pred = pred_samples.pred_original_sample + all_latent.append(latent.detach()) + all_pred.append(pred.detach()) + + return { + 'latent': latent, + 'all_latent': all_latent, + 'all_pred': all_pred + } + +class InferenceIP2PEditRef(Inference): + def zeros(self, x): + return torch.zeros_like(x) + @torch.no_grad() + def __call__( + self, + latent: torch.Tensor, + text_cond: torch.Tensor, + text_uncond: torch.Tensor, + img_cond: torch.Tensor, + edit_cond: torch.Tensor, + text_cfg = 7.5, + img_cfg = 1.2, + edit_cfg = 1.2, + start_time: int = 0, + ): + ''' + latent1 | latent2 | latent3 | latent4 + text x x x v + edit x x v v + img x v v v + ''' + all_latent = [] + all_pred = [] # x0_hat + for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])): + t = int(t) + latent1 = torch.cat([latent, self.zeros(img_cond), self.zeros(edit_cond)], dim=1) + latent2 = torch.cat([latent, img_cond, self.zeros(edit_cond)], dim=1) + latent3 = torch.cat([latent, img_cond, edit_cond], dim=1) + latent4 = latent3.clone() + latent_input = torch.cat([latent1, latent2, latent3, latent4], dim=0) + context_input = torch.cat([text_uncond, text_uncond, text_uncond, text_cond], dim=0) + noise_pred = self.unet( + latent_input, + torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long), + context={ 'text': context_input}, + ) + + noise_pred1, noise_pred2, noise_pred3, noise_pred4 = noise_pred.chunk(4, dim=0) + noise_pred = ( + noise_pred1 + + img_cfg * (noise_pred2 - noise_pred1) + + edit_cfg * (noise_pred3 - noise_pred2) + + text_cfg * (noise_pred4 - noise_pred3) + ) # when edit_cfg == img_cfg, noise_pred2 is not used + + pred_samples = self.scheduler.step(noise_pred, t, latent) + latent = pred_samples.prev_sample + pred = pred_samples.pred_original_sample + all_latent.append(latent.detach()) + all_pred.append(pred.detach()) + + return { + 'latent': latent, + 'all_latent': all_latent, + 'all_pred': all_pred + } + +class InferenceIP2PVideo(Inference): + def zeros(self, x): + return torch.zeros_like(x) + @torch.no_grad() + def __call__( + self, + latent: torch.Tensor, + text_cond: torch.Tensor, + text_uncond: torch.Tensor, + img_cond: torch.Tensor, + text_cfg = 7.5, + img_cfg = 1.2, + start_time: int = 0, + guidance_rescale: float = 0.0, + ): + ''' + latent1 | latent2 | latent3 + text x x v + img x v v + ''' + all_latent = [] + all_pred = [] # x0_hat + for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])): + t = int(t) + latent1 = torch.cat([latent, self.zeros(img_cond)], dim=2) + latent2 = torch.cat([latent, img_cond], dim=2) + latent3 = latent2.clone() + latent_input = torch.cat([latent1, latent2, latent3], dim=0) + context_input = torch.cat([text_uncond, text_uncond, text_cond], dim=0) + + latent_input = rearrange(latent_input, 'b f c h w -> b c f h w') + noise_pred = self.unet( + latent_input, + torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long), + encoder_hidden_states=context_input, + ).sample + noise_pred = rearrange(noise_pred, 'b c f h w -> b f c h w') + + + noise_pred1, noise_pred2, noise_pred3 = noise_pred.chunk(3, dim=0) + noise_pred = ( + noise_pred1 + + img_cfg * (noise_pred2 - noise_pred1) + + text_cfg * (noise_pred3 - noise_pred2) + ) + + if guidance_rescale > 0: + noise_pred = rescale_noise_cfg(noise_pred, noise_pred1, guidance_rescale=guidance_rescale) + + pred_samples = self.scheduler.step(noise_pred, t, latent) + latent = pred_samples.prev_sample + pred = pred_samples.pred_original_sample + all_latent.append(latent.detach()) + all_pred.append(pred.detach()) + + return { + 'latent': latent, + 'all_latent': all_latent, + 'all_pred': all_pred + } + + @torch.no_grad() + def second_clip_forward( + self, + latent: torch.Tensor, + text_cond: torch.Tensor, + text_uncond: torch.Tensor, + img_cond: torch.Tensor, + latent_ref: torch.Tensor, + noise_correct_step: float = 1., + text_cfg = 7.5, + img_cfg = 1.2, + start_time: int = 0, + guidance_rescale: float = 0.0, + ): + ''' + latent1 | latent2 | latent3 + text x x v + img x v v + ''' + num_ref_frames = latent_ref.shape[1] + all_latent = [] + all_pred = [] # x0_hat + for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])): + t = int(t) + + latent1 = torch.cat([latent, self.zeros(img_cond)], dim=2) + latent2 = torch.cat([latent, img_cond], dim=2) + latent3 = latent2.clone() + latent_input = torch.cat([latent1, latent2, latent3], dim=0) + context_input = torch.cat([text_uncond, text_uncond, text_cond], dim=0) + + latent_input = rearrange(latent_input, 'b f c h w -> b c f h w') + noise_pred = self.unet( + latent_input, + torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long), + encoder_hidden_states=context_input, + ).sample + noise_pred = rearrange(noise_pred, 'b c f h w -> b f c h w') + + noise_pred1, noise_pred2, noise_pred3 = noise_pred.chunk(3, dim=0) + noise_pred = ( + noise_pred1 + + img_cfg * (noise_pred2 - noise_pred1) + + text_cfg * (noise_pred3 - noise_pred2) + ) + + if guidance_rescale > 0: + noise_pred = rescale_noise_cfg(noise_pred, noise_pred1, guidance_rescale=guidance_rescale) + + # 所谓的再inference阶段加入 Long Video Sampling Correction(LVSC) + if noise_correct_step * self.num_ddim_steps > i: + alpha_prod_t = self.scheduler.alphas_cumprod[t] + beta_prod_t = 1 - alpha_prod_t + noise_ref = (latent[:, 0:num_ref_frames] - (alpha_prod_t ** 0.5) * latent_ref) / (beta_prod_t ** 0.5) # b 1 c h w + delta_noise_ref = noise_ref - noise_pred[:, 0:num_ref_frames] + delta_noise_remaining = delta_noise_ref.mean(dim=1, keepdim=True) + noise_pred[:, :num_ref_frames] = noise_pred[:, :num_ref_frames] + delta_noise_ref + noise_pred[:, num_ref_frames:] = noise_pred[:, num_ref_frames:] + delta_noise_remaining + + pred_samples = self.scheduler.step(noise_pred, t, latent) + latent = pred_samples.prev_sample + pred = pred_samples.pred_original_sample + all_latent.append(latent.detach()) + all_pred.append(pred.detach()) + + return { + 'latent': latent, + 'all_latent': all_latent, + 'all_pred': all_pred + } + +class InferenceIP2PVideoEnsemble(Inference): + def zeros(self, x): + return torch.zeros_like(x) + @torch.no_grad() + def __call__( + self, + latent: torch.Tensor, + text_cond: torch.Tensor, + text_uncond: torch.Tensor, + img_cond: torch.Tensor, + text_cfg = 7.5, + img_cfg = 1.2, + start_time: int = 0, + guidance_rescale: float = 0.0, + ): + ''' + latent1 | latent2 | latent3 + text x x v + img x v v + ''' + all_latent = [] + all_pred = [] # x0_hat + for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])): + t = int(t) + + latent1 = torch.cat([latent, self.zeros(img_cond)], dim=2) + latent2 = torch.cat([latent, img_cond], dim=2) + latent3 = latent2.clone() + latent_input = torch.cat([latent1, latent2, latent3], dim=0) + context_input = torch.cat([text_uncond, text_uncond, text_cond], dim=0) + + latent_input = rearrange(latent_input, 'b f c h w -> b c f h w') + noise_pred = self.unet( + latent_input, + torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long), + encoder_hidden_states=context_input, + ).sample + noise_pred = rearrange(noise_pred, 'b c f h w -> b f c h w') + + + noise_pred1, noise_pred2, noise_pred3 = noise_pred.chunk(3, dim=0) + noise_pred = ( + noise_pred1 + + img_cfg * (noise_pred2 - noise_pred1) + + text_cfg * (noise_pred3 - noise_pred2) + ) + + if guidance_rescale > 0: + noise_pred = rescale_noise_cfg(noise_pred, noise_pred1, guidance_rescale=guidance_rescale) + + pred_samples = self.scheduler.step(noise_pred, t, latent) + latent = pred_samples.prev_sample + # average over all three samples. + latent = latent.mean(dim=0, keepdim=True).repeat(latent.shape[0], 1, 1, 1, 1) + # latent = latent[[0]].repeat(latent.shape[0], 1, 1, 1, 1) + pred = pred_samples.pred_original_sample + all_latent.append(latent.detach()) + all_pred.append(pred.detach()) + + return { + 'latent': latent, + 'all_latent': all_latent, + 'all_pred': all_pred + } + + @torch.no_grad() + def second_clip_forward( + self, + latent: torch.Tensor, + text_cond: torch.Tensor, + text_uncond: torch.Tensor, + img_cond: torch.Tensor, + latent_ref: torch.Tensor, + noise_correct_step: float = 1., + text_cfg = 7.5, + img_cfg = 1.2, + start_time: int = 0, + guidance_rescale: float = 0.0, + ): + ''' + latent1 | latent2 | latent3 + text x x v + img x v v + ''' + num_ref_frames = latent_ref.shape[1] + all_latent = [] + all_pred = [] # x0_hat + for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])): + t = int(t) + + latent1 = torch.cat([latent, self.zeros(img_cond)], dim=2) + latent2 = torch.cat([latent, img_cond], dim=2) + latent3 = latent2.clone() + latent_input = torch.cat([latent1, latent2, latent3], dim=0) + context_input = torch.cat([text_uncond, text_uncond, text_cond], dim=0) + + latent_input = rearrange(latent_input, 'b f c h w -> b c f h w') + noise_pred = self.unet( + latent_input, + torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long), + encoder_hidden_states=context_input, + ).sample + noise_pred = rearrange(noise_pred, 'b c f h w -> b f c h w') + + noise_pred1, noise_pred2, noise_pred3 = noise_pred.chunk(3, dim=0) + noise_pred = ( + noise_pred1 + + img_cfg * (noise_pred2 - noise_pred1) + + text_cfg * (noise_pred3 - noise_pred2) + ) + + if guidance_rescale > 0: + noise_pred = rescale_noise_cfg(noise_pred, noise_pred1, guidance_rescale=guidance_rescale) + + # 所谓的再inference阶段加入 Long Video Sampling Correction(LVSC) + if noise_correct_step * self.num_ddim_steps > i: + alpha_prod_t = self.scheduler.alphas_cumprod[t] + beta_prod_t = 1 - alpha_prod_t + noise_ref = (latent[:, 0:num_ref_frames] - (alpha_prod_t ** 0.5) * latent_ref) / (beta_prod_t ** 0.5) # b 1 c h w + delta_noise_ref = noise_ref - noise_pred[:, 0:num_ref_frames] + delta_noise_remaining = delta_noise_ref.mean(dim=1, keepdim=True) + noise_pred[:, :num_ref_frames] = noise_pred[:, :num_ref_frames] + delta_noise_ref + noise_pred[:, num_ref_frames:] = noise_pred[:, num_ref_frames:] + delta_noise_remaining + + pred_samples = self.scheduler.step(noise_pred, t, latent) + latent = pred_samples.prev_sample + pred = pred_samples.pred_original_sample + all_latent.append(latent.detach()) + all_pred.append(pred.detach()) + + return { + 'latent': latent, + 'all_latent': all_latent, + 'all_pred': all_pred + } + + +class InferenceIP2PVideoHDR(Inference): + def zeros(self, x): + return torch.zeros_like(x) + @torch.no_grad() + def __call__( + self, + latent: torch.Tensor, + text_cond: torch.Tensor, + text_uncond: torch.Tensor,#(1,77,768) + hdr_cond: torch.Tensor, #(1,3,768) + img_cond: torch.Tensor, + text_cfg = 7.5, + img_cfg = 1.2, + hdr_cfg = 7.5, + start_time: int = 0, + guidance_rescale: float = 0.0, + ): + ''' + latent1 | latent2 | latent3 | latent4 + text x x v v + img x v v v + hdr x x x v + ''' + all_latent = [] + all_pred = [] # x0_hat + for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])): + t = int(t) + latent1 = torch.cat([latent, self.zeros(img_cond)], dim=2) + latent2 = torch.cat([latent, img_cond], dim=2) + latent3 = latent2.clone() + latent4 = latent2.clone() + latent_input = torch.cat([latent1, latent2, latent3, latent4], dim=0) + context_input = torch.cat([text_uncond, text_uncond, text_cond, text_cond], dim=0) #(4,77,768) + + hdr_uncond = self.zeros(hdr_cond) + hdr_input = torch.cat([hdr_uncond, hdr_uncond, hdr_uncond, hdr_cond]) #(4,3,768) + + model_kwargs1 = {'hdr_latents': hdr_input, 'encoder_hidden_states': context_input} + + latent_input = rearrange(latent_input, 'b f c h w -> b c f h w') + noise_pred = self.unet( + latent_input, + torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long), + encoder_hidden_states=model_kwargs1, + ).sample + noise_pred = rearrange(noise_pred, 'b c f h w -> b f c h w') + + + noise_pred1, noise_pred2, noise_pred3, noise_pred4 = noise_pred.chunk(4, dim=0) + noise_pred = ( + noise_pred1 + + img_cfg * (noise_pred2 - noise_pred1) + + text_cfg * (noise_pred3 - noise_pred2) + + hdr_cfg * (noise_pred4 - noise_pred3) + ) + + if guidance_rescale > 0: + noise_pred = rescale_noise_cfg(noise_pred, noise_pred1, guidance_rescale=guidance_rescale) + + pred_samples = self.scheduler.step(noise_pred, t, latent) + latent = pred_samples.prev_sample + pred = pred_samples.pred_original_sample + all_latent.append(latent.detach()) + all_pred.append(pred.detach()) + + return { + 'latent': latent, + 'all_latent': all_latent, + 'all_pred': all_pred + } + + @torch.no_grad() + def second_clip_forward( + self, + latent: torch.Tensor, + text_cond: torch.Tensor, + text_uncond: torch.Tensor, + img_cond: torch.Tensor, + latent_ref: torch.Tensor, + noise_correct_step: float = 1., + text_cfg = 7.5, + img_cfg = 1.2, + start_time: int = 0, + guidance_rescale: float = 0.0, + ): + ''' + latent1 | latent2 | latent3 + text x x v + img x v v + ''' + num_ref_frames = latent_ref.shape[1] + all_latent = [] + all_pred = [] # x0_hat + for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])): + t = int(t) + + latent1 = torch.cat([latent, self.zeros(img_cond)], dim=2) + latent2 = torch.cat([latent, img_cond], dim=2) + latent3 = latent2.clone() + latent_input = torch.cat([latent1, latent2, latent3], dim=0) + context_input = torch.cat([text_uncond, text_uncond, text_cond], dim=0) + + latent_input = rearrange(latent_input, 'b f c h w -> b c f h w') + noise_pred = self.unet( + latent_input, + torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long), + encoder_hidden_states=context_input, + ).sample + noise_pred = rearrange(noise_pred, 'b c f h w -> b f c h w') + + noise_pred1, noise_pred2, noise_pred3 = noise_pred.chunk(3, dim=0) + noise_pred = ( + noise_pred1 + + img_cfg * (noise_pred2 - noise_pred1) + + text_cfg * (noise_pred3 - noise_pred2) + ) + + if guidance_rescale > 0: + noise_pred = rescale_noise_cfg(noise_pred, noise_pred1, guidance_rescale=guidance_rescale) + + # 所谓的再inference阶段加入 Long Video Sampling Correction(LVSC) + if noise_correct_step * self.num_ddim_steps > i: + alpha_prod_t = self.scheduler.alphas_cumprod[t] + beta_prod_t = 1 - alpha_prod_t + noise_ref = (latent[:, 0:num_ref_frames] - (alpha_prod_t ** 0.5) * latent_ref) / (beta_prod_t ** 0.5) # b 1 c h w + delta_noise_ref = noise_ref - noise_pred[:, 0:num_ref_frames] + delta_noise_remaining = delta_noise_ref.mean(dim=1, keepdim=True) + noise_pred[:, :num_ref_frames] = noise_pred[:, :num_ref_frames] + delta_noise_ref + noise_pred[:, num_ref_frames:] = noise_pred[:, num_ref_frames:] + delta_noise_remaining + + pred_samples = self.scheduler.step(noise_pred, t, latent) + latent = pred_samples.prev_sample + pred = pred_samples.pred_original_sample + all_latent.append(latent.detach()) + all_pred.append(pred.detach()) + + return { + 'latent': latent, + 'all_latent': all_latent, + 'all_pred': all_pred + } + + +class InferenceIP2PVideoOpticalFlow(InferenceIP2PVideo): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.flow_estimator = RAFTFlow().cuda() # 使用光流估计器 + + def obtain_delta_noise(self, delta_noise_ref, flow): + flow = resize_flow(flow, delta_noise_ref.shape[2:]) + warped_delta_noise_ref = warp_image(delta_noise_ref, flow) # 根据光流扭曲参考帧的噪声差异 + valid_mask = torch.ones_like(delta_noise_ref)[:, :1] + valid_mask = warp_image(valid_mask, flow) + return warped_delta_noise_ref, valid_mask + + def obtain_flow_batched(self, ref_images, query_images): + ref_images = ref_images.to() + warp_funcs = [] + for query_image in query_images: + query_image = query_image.unsqueeze(0).repeat(len(ref_images), 1, 1, 1) + flow = self.flow_estimator(query_image, ref_images) # 估计光流 + warp_func = partial(self.obtain_delta_noise, flow=flow) + warp_funcs.append(warp_func) + return warp_funcs + + @torch.no_grad() + def second_clip_forward( + self, + latent: torch.Tensor, + text_cond: torch.Tensor, + text_uncond: torch.Tensor, + img_cond: torch.Tensor, + latent_ref: torch.Tensor, + ref_images: torch.Tensor, + query_images: torch.Tensor, + noise_correct_step: float = 1., + text_cfg = 7.5, + img_cfg = 1.2, + start_time: int = 0, + guidance_rescale: float = 0.0, + ): + ''' + latent1 | latent2 | latent3 + text x x v + img x v v + ''' + assert ref_images.shape[0] == 1, 'only support batch size 1' + warp_funcs = self.obtain_flow_batched(ref_images[0], query_images[0]) + num_ref_frames = latent_ref.shape[1] + all_latent = [] + all_pred = [] # x0_hat + for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])): + t = int(t) + + latent1 = torch.cat([latent, self.zeros(img_cond)], dim=2) + latent2 = torch.cat([latent, img_cond], dim=2) + latent3 = latent2.clone() + latent_input = torch.cat([latent1, latent2, latent3], dim=0) + context_input = torch.cat([text_uncond, text_uncond, text_cond], dim=0) + + latent_input = rearrange(latent_input, 'b f c h w -> b c f h w') + noise_pred = self.unet( + latent_input, + torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long), + encoder_hidden_states=context_input, + ).sample + noise_pred = rearrange(noise_pred, 'b c f h w -> b f c h w') + + noise_pred1, noise_pred2, noise_pred3 = noise_pred.chunk(3, dim=0) + noise_pred = ( + noise_pred1 + + img_cfg * (noise_pred2 - noise_pred1) + + text_cfg * (noise_pred3 - noise_pred2) + ) + + if guidance_rescale > 0: + noise_pred = rescale_noise_cfg(noise_pred, noise_pred1, guidance_rescale=guidance_rescale) + + + if noise_correct_step * self.num_ddim_steps > i: + alpha_prod_t = self.scheduler.alphas_cumprod[t] + beta_prod_t = 1 - alpha_prod_t + noise_ref = (latent[:, 0:num_ref_frames] - (alpha_prod_t ** 0.5) * latent_ref) / (beta_prod_t ** 0.5) # b 1 c h w + delta_noise_ref = noise_ref - noise_pred[:, 0:num_ref_frames] + noise_pred[:, :num_ref_frames] = noise_pred[:, :num_ref_frames] + delta_noise_ref + + for refed_index, warp_func in zip(range(num_ref_frames, noise_pred.shape[1]), warp_funcs): + delta_noise_remaining, delta_noise_mask = warp_func(delta_noise_ref[0]) + mask_sum = delta_noise_mask[None].sum(dim=1, keepdim=True) + delta_noise_remaining = torch.where( + mask_sum > 0.5, + delta_noise_remaining[None].sum(dim=1, keepdim=True) / mask_sum, + 0. + ) + noise_pred[:, refed_index: refed_index+1] += torch.where( + mask_sum > 0.5, + delta_noise_remaining, + 0 + ) # 将这个扭曲的噪声差异应用到当前帧,确保帧之间的噪声变化符合视频中物体的移动 + + pred_samples = self.scheduler.step(noise_pred, t, latent) + latent = pred_samples.prev_sample + pred = pred_samples.pred_original_sample + all_latent.append(latent.detach()) + all_pred.append(pred.detach()) + + return { + 'latent': latent, + 'all_latent': all_latent, + 'all_pred': all_pred + } \ No newline at end of file diff --git a/pl_trainer/inference/inference_damo.py b/pl_trainer/inference/inference_damo.py new file mode 100644 index 0000000000000000000000000000000000000000..45581b01489f028e39b2655a0a265f3b0d0255cd --- /dev/null +++ b/pl_trainer/inference/inference_damo.py @@ -0,0 +1,307 @@ +import torch +from typing import List, Union, Tuple +from tqdm import tqdm +from .inference import Inference + +class InferenceDAMO(Inference): + @torch.no_grad() + def __call__( + self, + latent: torch.Tensor, + context: torch.Tensor, + uncond_context: torch.Tensor=None, + start_time: int = 0, + null_embedding: List[torch.Tensor]=None, + ): + all_latent = [] + all_pred = [] # x0_hat + do_classifier_free_guidance = self.guidance_scale > 1 and ((uncond_context is not None) or (null_embedding is not None)) + for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])): + t = int(t) + if do_classifier_free_guidance: + latent_input = torch.cat([latent, latent], dim=0) + if null_embedding is not None: + context_input = torch.cat([null_embedding[i], context], dim=0) + else: + context_input = torch.cat([uncond_context, context], dim=0) + else: + latent_input = latent + context_input = context + noise_pred = self.unet( + latent_input, + torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long), + context_input, + ) + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + + pred_samples = self.scheduler.step(noise_pred, t, latent) + latent = pred_samples.prev_sample + pred = pred_samples.pred_original_sample + all_latent.append(latent.detach()) + all_pred.append(pred.detach()) + + return { + 'latent': latent, + 'all_latent': all_latent, + 'all_pred': all_pred + } + +class InferenceDAMO_PTP(Inference): + def infer_old_context(self, latent, context, t, uncond_context=None): + do_classifier_free_guidance = self.guidance_scale > 1 and (uncond_context is not None) + + if do_classifier_free_guidance: + latent_input = torch.cat([latent, latent], dim=0) + context_input = torch.cat([uncond_context, context], dim=0) + else: + latent_input = latent + context_input = context + + noise_pred = self.unet( + latent_input, + torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long), + context_input, + ) + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + + pred_samples = self.scheduler.step(noise_pred, t, latent) + latent = pred_samples.prev_sample + pred = pred_samples.pred_original_sample + return latent, pred + + def infer_new_context(self, latent, context, t, uncond_context=None): + do_classifier_free_guidance = self.guidance_scale > 1 and (uncond_context is not None) + + if do_classifier_free_guidance: + latent_input = torch.cat([latent, latent], dim=0) + if isinstance(context, (list, tuple)): + context_input = ( + torch.cat([uncond_context, context[0]], dim=0), + torch.cat([uncond_context, context[1]], dim=0), + ) + else: + context_input = torch.cat([uncond_context, context], dim=0) + else: + latent_input = latent + context_input = context + + noise_pred = self.unet( + latent_input, + torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long), + context_input, + ) + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + + pred_samples = self.scheduler.step(noise_pred, t, latent) + latent = pred_samples.prev_sample + pred = pred_samples.pred_original_sample + return latent, pred + + @torch.no_grad() + def __call__( + self, + latent: torch.Tensor, + context: torch.Tensor, # used when > ca_end_time + old_context: torch.Tensor=None, # used when < sa_end_time + old_to_new_context: Union[Tuple, List]=None, # used when sa_end_time < t < ca_end_time + uncond_context: torch.Tensor=None, + sa_end_time: float=0.3, + ca_end_time: float=0.8, + start_time: int = 0, + ): + assert sa_end_time < ca_end_time, f"sa_end_time must be less than ca_end_time, got {sa_end_time} and {ca_end_time} respectively" + all_latent = [] + all_pred = [] + all_latent_old = [] + all_pred_old = [] + old_latent = latent.clone() + new_latent = latent.clone() + for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])): + t = int(t) + old_latent_next_t, pred_old = self.infer_old_context(old_latent, old_context, t, uncond_context) + if i < sa_end_time * self.num_ddim_steps: + new_latent_next_t, pred_new = old_latent_next_t, pred_old + elif sa_end_time * self.num_ddim_steps <= i < ca_end_time * self.num_ddim_steps: + new_latent_next_t, pred_new = self.infer_new_context( + new_latent, old_to_new_context, t, uncond_context + ) + else: + new_latent_next_t, pred_new = self.infer_new_context( + new_latent, context, t, uncond_context + ) + + old_latent = old_latent_next_t + new_latent = new_latent_next_t + + all_latent.append(new_latent_next_t.detach()) + all_pred.append(pred_new.detach()) + all_latent_old.append(old_latent_next_t.detach()) + all_pred_old.append(pred_old.detach()) + + return { + 'latent': new_latent, + 'latent_old': old_latent, + 'all_latent': all_latent, + 'all_pred': all_pred, + 'all_latent_old': all_latent_old, + 'all_pred_old': all_pred_old, + } + +class InferenceDAMO_PTP_v2(Inference): + def set_ptp_in_xattn_layers(self, prompt_to_prompt: bool, num_frames=1): + for m in self.unet.modules(): + if m.__class__.__name__ == 'CrossAttention': + m.ptp_sa_replace = prompt_to_prompt + m.num_frames = num_frames + + def infer_both_with_sa_replace(self, old_latent, new_latent, old_context, new_context, t, uncond_context=None): + do_classifier_free_guidance = self.guidance_scale > 1 and (uncond_context is not None) + + if do_classifier_free_guidance: + latent_input = torch.cat([old_latent, new_latent, old_latent, new_latent], dim=0) + context_input = torch.cat([uncond_context, uncond_context, old_context, new_context], dim=0) + else: + latent_input = torch.cat([old_latent, new_latent], dim=0) + context_input = torch.cat([old_context, new_context], dim=0) + + noise_pred = self.unet( + latent_input, + torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long), + context_input, + ) + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + + noise_pred_old, noise_pred_new = noise_pred.chunk(2, dim=0) + pred_samples_old = self.scheduler.step(noise_pred_old, t, old_latent) + pred_samples_new = self.scheduler.step(noise_pred_new, t, new_latent) + + old_latent = pred_samples_old.prev_sample + new_latent = pred_samples_new.prev_sample + old_pred = pred_samples_old.pred_original_sample + new_pred = pred_samples_new.pred_original_sample + + return old_latent, new_latent, old_pred, new_pred + + def infer_old_context(self, latent, context, t, uncond_context=None): + do_classifier_free_guidance = self.guidance_scale > 1 and (uncond_context is not None) + + if do_classifier_free_guidance: + latent_input = torch.cat([latent, latent], dim=0) + context_input = torch.cat([uncond_context, context], dim=0) + else: + latent_input = latent + context_input = context + + noise_pred = self.unet( + latent_input, + torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long), + context_input, + ) + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + + pred_samples = self.scheduler.step(noise_pred, t, latent) + latent = pred_samples.prev_sample + pred = pred_samples.pred_original_sample + return latent, pred + + def infer_new_context(self, latent, context, t, uncond_context=None): + do_classifier_free_guidance = self.guidance_scale > 1 and (uncond_context is not None) + + if do_classifier_free_guidance: + latent_input = torch.cat([latent, latent], dim=0) + if isinstance(context, (list, tuple)): + context_input = ( + torch.cat([uncond_context, context[0]], dim=0), + torch.cat([uncond_context, context[1]], dim=0), + ) + else: + context_input = torch.cat([uncond_context, context], dim=0) + else: + latent_input = latent + context_input = context + + noise_pred = self.unet( + latent_input, + torch.full((len(latent_input),), t, device=latent_input.device, dtype=torch.long), + context_input, + ) + + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2, dim=0) + noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond) + + pred_samples = self.scheduler.step(noise_pred, t, latent) + latent = pred_samples.prev_sample + pred = pred_samples.pred_original_sample + return latent, pred + + @torch.no_grad() + def __call__( + self, + latent: torch.Tensor, + context: torch.Tensor, # used when > ca_end_time + old_context: torch.Tensor=None, # used when < sa_end_time + old_to_new_context: Union[Tuple, List]=None, # used when sa_end_time < t < ca_end_time + uncond_context: torch.Tensor=None, + sa_end_time: float=0.3, + ca_end_time: float=0.8, + start_time: int = 0, + ): + assert sa_end_time < ca_end_time, f"sa_end_time must be less than ca_end_time, got {sa_end_time} and {ca_end_time} respectively" + all_latent = [] + all_pred = [] + all_latent_old = [] + all_pred_old = [] + old_latent = latent.clone() + new_latent = latent.clone() + for i, t in enumerate(tqdm(self.scheduler.timesteps[start_time:])): + t = int(t) + if i < sa_end_time * self.num_ddim_steps: + self.set_ptp_in_xattn_layers(True, num_frames=latent.shape[2]) + old_latent_next_t, new_latent_next_t, pred_old, pred_new = self.infer_both_with_sa_replace( + old_latent, new_latent, old_context, context, t, uncond_context + ) + elif sa_end_time * self.num_ddim_steps <= i < ca_end_time * self.num_ddim_steps: + self.set_ptp_in_xattn_layers(False) + old_latent_next_t, pred_old = self.infer_old_context(old_latent, old_context, t, uncond_context) + new_latent_next_t, pred_new = self.infer_new_context( + new_latent, old_to_new_context, t, uncond_context + ) + else: + self.set_ptp_in_xattn_layers(False) + old_latent_next_t, pred_old = self.infer_old_context(old_latent, old_context, t, uncond_context) + new_latent_next_t, pred_new = self.infer_new_context( + new_latent, context, t, uncond_context + ) + + old_latent = old_latent_next_t + new_latent = new_latent_next_t + + all_latent.append(new_latent_next_t.detach()) + all_pred.append(pred_new.detach()) + all_latent_old.append(old_latent_next_t.detach()) + all_pred_old.append(pred_old.detach()) + + return { + 'latent': new_latent, + 'latent_old': old_latent, + 'all_latent': all_latent, + 'all_pred': all_pred, + 'all_latent_old': all_latent_old, + 'all_pred_old': all_pred_old, + } \ No newline at end of file diff --git a/pl_trainer/instruct_p2p_video.py b/pl_trainer/instruct_p2p_video.py new file mode 100644 index 0000000000000000000000000000000000000000..e3bed30e0f3abe1ea1a766cebfdbea4c746300b7 --- /dev/null +++ b/pl_trainer/instruct_p2p_video.py @@ -0,0 +1,866 @@ +''' +Use pretrained instruct pix2pix model but add additional channels for reference modification +''' + +import torch +from .diffusion import DDIMLDMTextTraining +from einops import rearrange + +from modules.video_unet_temporal.resnet import InflatedConv3d +from safetensors.torch import load_file + +import torch.nn.functional as F + +from torch import nn +import cv2 +from torch.hub import download_url_to_file + +class MLP(nn.Module): + def __init__(self): + super(MLP, self).__init__() + self.fc1 = nn.Linear(3072, 4096) + self.fc2 = nn.Linear(4096, 4096) + self.fc3 = nn.Linear(4096, 4096) + self.fc4 = nn.Linear(4096, 2304) + self.leaky_relu = nn.LeakyReLU(negative_slope=0.01) # 设置Leaky ReLU的负斜率 + + def forward(self, x): + x = self.leaky_relu(self.fc1(x)) + x = self.leaky_relu(self.fc2(x)) + x = self.leaky_relu(self.fc3(x)) + x = self.fc4(x) + return x + +# class CombineMLP(nn.Module): +# def __init__(self, input_dim=128, output_dim=64, hidden_dim=128): +# """ +# 构造一个 5 层 MLP 网络。 +# :param input_dim: 输入的特征维度,默认 128 +# :param output_dim: 输出的特征维度,默认 64 +# :param hidden_dim: 隐藏层维度,默认 128 +# """ +# super(CombineMLP, self).__init__() + +# # 定义 5 层 MLP +# self.fc1 = nn.Linear(input_dim, hidden_dim) #() +# self.fc2 = nn.Linear(hidden_dim, hidden_dim) +# self.fc3 = nn.Linear(hidden_dim, hidden_dim) +# self.fc4 = nn.Linear(hidden_dim, hidden_dim) +# self.fc5 = nn.Linear(hidden_dim, output_dim) # 最后一层映射到 64 + +# # 定义激活函数 +# # self.activation = nn.ReLU() +# self.activation = nn.LeakyReLU(negative_slope=0.01) # 默认负斜率为 0.01 + + +# def forward(self, x1, x2): +# """ +# 前向传播,支持两个输入 x1 和 x2 +# :param x1: 第一个输入,形状 (B, 64) +# :param x2: 第二个输入,形状 (B, 64) +# :return: 输出特征,形状 (B, 64) +# """ +# # 将两个输入拼接在一起 +# x = torch.cat([x1, x2], dim=-1) # 拼接后形状为 (B, 128) + +# # 依次通过 5 层 MLP 和激活函数 +# x = self.activation(self.fc1(x)) +# x = self.activation(self.fc2(x)) +# x = self.activation(self.fc3(x)) +# x = self.activation(self.fc4(x)) +# x = self.fc5(x) # 最后一层不使用激活函数(根据需求) + +# return x + + +class CombineMLP(nn.Module): + def __init__(self, input_dim=4*64*64*2, output_dim=4*64*64, hidden_dim=128, num_layers=5): + """ + 构造一个 5 层 MLP 网络。 + :param input_dim: 输入的特征维度,默认 128 + :param output_dim: 输出的特征维度,默认 64 + :param hidden_dim: 隐藏层维度,默认 128 + """ + super(CombineMLP, self).__init__() + + # 创建多个隐藏层 + layers = [] + for i in range(num_layers - 1): # 生成 num_layers-1 个隐藏层 + layers.append(nn.Linear(input_dim if i == 0 else hidden_dim, hidden_dim)) + layers.append(nn.ReLU()) + + # 输出层 + layers.append(nn.Linear(hidden_dim, output_dim)) + + # 将层组合成一个模块 + self.mlp = nn.Sequential(*layers) + + + def forward(self, x1, x2): + """ + 前向传播,支持两个输入 x1 和 x2 + :param x1: 第一个输入,形状 (1,16,4,64,64) + :param x2: 第二个输入,形状 (1,16,4,64,64) + :return: 输出特征,形状 (1,16,4,64,64) + """ + # import pdb; pdb.set_trace() + # 将两个输入拼接在一起 + x = torch.cat([x1, x2], dim=2) # 拼接后形状为 (1,16,8,64,64) + x = torch.flatten(x, start_dim=2) # Flatten to shape (batch_size, 16, 8*64*64) + x = self.mlp(x) # Apply MLP 1,16,16384 + x = x.reshape(x.size(0), x.size(1), 4, 64, 64) # Reshape back to (1, 16, 4, 64, 64) + + return x + + + +class HDRCtrlModeltmp(nn.Module): + def __init__(self): + super(HDRCtrlModel, self).__init__() + + # 定义卷积层 + self.conv_layer1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=4, padding=1) + self.conv_layer2 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=2, padding=1) + + # 定义 MLP 模型 + self.mlp = MLP() + + def decompose_hdr(self, hdr_latents): + batch_size, channels, height, width = hdr_latents.shape + device = hdr_latents.device # 获取设备信息 + + # 生成 4×4 掩码 (batch_size, 1, 4, 4) + mask_small = torch.rand(batch_size, 1, 4, 4, device=device) # 从均匀分布生成随机掩码 + + # 将掩码调整为与输入相同的大小 (batch_size, 1, height, width) + mask = torch.nn.functional.interpolate(mask_small, size=(height, width), mode='bilinear', align_corners=False) + + # 保持连续值,不进行二值化 #! 注意此步操作, 注意可视化 random mask的结果... 首先可以可视化mask, 其次可视化 + mask = mask.expand(-1, channels, -1, -1) # 扩展掩码通道数以匹配 hdr_latents 的形状 + + # 应用 mask 生成 L1 和 L2 + hdr_latents_1 = hdr_latents * mask # L1 = 掩码部分 + hdr_latents_2 = hdr_latents * (1 - mask) # L2 = 非掩码部分 + + return hdr_latents_1, hdr_latents_2 + + def forward(self, hdr_latents): + # import pdb; pdb.set_trace() + # todo: mask get hdr1, hdr2; input hdr_latents(实际上暂时是ldr) + # 输入的形状为 (1, 16, 3, 256, 512),去掉多余的维度 + # import pdb; pdb.set_trace() + hdr_latents = hdr_latents.squeeze(0) # 变成 (16, 3, 256, 512) + + batch_size = hdr_latents.shape[0] + + # 转换为 NCHW 形式: (batch, channels, height, width) 输入之前numpy2tensor已经permute过了 + # hdr_latents = hdr_latents.permute(0, 3, 1, 2) #! 注意一下to tensor? (如何进行归一化的) 的时候已经 + # 进行卷积操作 + conv_output = self.conv_layer1(hdr_latents) #! 注意更改此处卷积!!! + conv_output = self.conv_layer2(conv_output) # (16, 3, 32, 64) + + # 截取前 32 列,得到最终形状 (16, 3, 32, 32) + hdr_latents = conv_output[:, :, :, :32] + # todo: decompose hdr + hdr_latents_1, hdr_latents_2 = self.decompose_hdr(hdr_latents) # [16, 3, 32, 32], [16, 3, 32, 32] + + # 将输出展平,准备输入到 MLP 中 + hdr_latents = hdr_latents.reshape(hdr_latents.size(0), -1) # [16, 3072] + hdr_latents_1 = hdr_latents_1.reshape(hdr_latents_1.size(0), -1) # [16, 3072] + hdr_latents_2 = hdr_latents_2.reshape(hdr_latents_2.size(0), -1) + + # 传递给 MLP + hdr_latents = self.mlp(hdr_latents) #(16, 2304) 3072 -> 2304 + hdr_latents_1 = self.mlp(hdr_latents_1) + hdr_latents_2 = self.mlp(hdr_latents_2) + + # 重新调整输出的形状 + hdr_latents = hdr_latents.reshape(batch_size, 3, 768) # reshape 输出为 (16, 3, 768) + hdr_latents_1 = hdr_latents_1.reshape(batch_size, 3, 768) + hdr_latents_2 = hdr_latents_2.reshape(batch_size, 3, 768) + + return hdr_latents, hdr_latents_1, hdr_latents_2 + +class HDRCtrlModel(nn.Module): + def __init__(self): + super(HDRCtrlModel, self).__init__() + + # 定义卷积层 + self.conv_layer1 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=4, padding=1) + self.conv_layer2 = nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, stride=2, padding=1) + + # 定义 MLP 模型 + self.mlp = MLP() + + def decompose_hdr(self, hdr_latents): # hdr_latents: 16,3,32,32 可以可视化一下这部分代码... + batch_size, channels, height, width = hdr_latents.shape + device = hdr_latents.device # 获取设备信息 + + # 生成 4×4 掩码 (batch_size, 1, 4, 4) + mask_small = torch.rand(batch_size, 1, 4, 4, device=device) # 从均匀分布生成随机掩码 + + threshold = 0.5 # 调节阈值,增加黑色部分比例 + mask_small = (mask_small > threshold).float() + # 将掩码调整为与输入相同的大小 (batch_size, 1, height, width) 16,1,32,32 + mask = torch.nn.functional.interpolate(mask_small, size=(height, width), mode='bilinear', align_corners=False) + + # import pdb; pdb.set_trace() + # 保持连续值,不进行二值化 #! 注意此步操作, 注意可视化 random mask的结果... 首先可以可视化mask, 其次可视化 + mask = mask.expand(-1, channels, -1, -1) # 扩展掩码通道数以匹配 hdr_latents 的形状 + + # 应用 mask 生成 L1 和 L2 + hdr_latents_1 = hdr_latents * mask # L1 = 掩码部分 + hdr_latents_2 = hdr_latents * (1 - mask) # L2 = 非掩码部分 + + return hdr_latents_1, hdr_latents_2 + + def blur_image(self, hdr_latents): + # 高斯模糊, 输入 (16,3,256,256) + processed_images = [] + kernel_size = (15, 15) + sigmaX = 10 + + # 对每张图像进行处理 + for i in range(hdr_latents.size(0)): # 遍历16张图像 + # 获取第i张图像 + image = hdr_latents[i].permute(1, 2, 0).cpu().numpy() # 将形状变为 (256, 256, 3) + + # 进行高斯模糊 + blurred_image = cv2.GaussianBlur(image, kernel_size, sigmaX) + + # 将图像缩放到 (32, 32, 3) + resized_image = cv2.resize(blurred_image, (32, 32), interpolation=cv2.INTER_AREA) + + # 将处理后的图像从 numpy 数组转换回 tensor + resized_image_tensor = torch.tensor(resized_image, dtype=torch.uint8, device=hdr_latents.device).permute(2, 0, 1) # 转回 (3, 32, 32) + + # 将处理后的图像添加到列表中 + processed_images.append(resized_image_tensor) + + # 将列表中的所有图像堆叠成一个 tensor + processed_images_tensor = torch.stack(processed_images) # 形状为 (16, 3, 32, 32) + + return processed_images_tensor + + def normalize_hdr(self, img): + img = img / 255.0 + return img * 2 -1 + + def forward(self, hdr_latents): + # import pdb; pdb.set_trace() + # todo: mask get hdr1, hdr2; input hdr_latents(实际上暂时是ldr) + # 输入的形状为 (n, 16, 3, 256, 256),去掉多余的维度 + # import pdb; pdb.set_trace() + # hdr_latents = hdr_latents.squeeze(0) # 变成 (16, 3, 256, 256) + batch_size_ori = hdr_latents.shape[0] + # frame_num = hdr_latents.shape[1] + + hdr_latents = rearrange(hdr_latents, 'b f c h w -> (b f) c h w') + batch_size = hdr_latents.shape[0] + # batch_size = hdr_latents.shape[0] + # 转换为 NCHW 形式: (batch, channels, height, width) 输入之前numpy2tensor已经permute过了 + # 高斯模糊 + hdr_latents = self.blur_image(hdr_latents) #(16,3,32,32) 可视化打印一下! + + # import pdb; pdb.set_trace() + # todo: decompose hdr + hdr_latents_1, hdr_latents_2 = self.decompose_hdr(hdr_latents) # [16, 3, 32, 32], [16, 3, 32, 32] + + # todo: 加一步 normalize /255 -> -1,1 + hdr_latents, hdr_latents_1, hdr_latents_2 = self.normalize_hdr(hdr_latents), self.normalize_hdr(hdr_latents_1), self.normalize_hdr(hdr_latents_2) + + # import pdb; pdb.set_trace() + # 将输出展平,准备输入到 MLP 中 + hdr_latents = hdr_latents.reshape(hdr_latents.size(0), -1) # [16, 3072] + hdr_latents_1 = hdr_latents_1.reshape(hdr_latents_1.size(0), -1) # [16, 3072] + hdr_latents_2 = hdr_latents_2.reshape(hdr_latents_2.size(0), -1) + + # 传递给 MLP + hdr_latents = self.mlp(hdr_latents) #(16, 2304) 3072 -> 2304 + hdr_latents_1 = self.mlp(hdr_latents_1) + hdr_latents_2 = self.mlp(hdr_latents_2) + + # 重新调整输出的形状 + hdr_latents = hdr_latents.reshape(batch_size, 3, 768) # reshape 输出为 (16*n, 3, 768) + hdr_latents_1 = hdr_latents_1.reshape(batch_size, 3, 768) + hdr_latents_2 = hdr_latents_2.reshape(batch_size, 3, 768) + + hdr_latents = rearrange(hdr_latents, '(b f) n c -> b f n c', b=batch_size_ori) + hdr_latents_1 = rearrange(hdr_latents_1, '(b f) n c -> b f n c', b=batch_size_ori) + hdr_latents_2 = rearrange(hdr_latents_2, '(b f) n c -> b f n c', b=batch_size_ori) + + #! 两个细节: 1. 仅有ldr, 需不需要concat hdr或线性变换 2. mask不同帧不一致 + return hdr_latents, hdr_latents_1, hdr_latents_2 # 3 x (b,16,3,768) + + +class InstructP2PVideoTrainer(DDIMLDMTextTraining): + def __init__( + self, *args, + cond_image_dropout=0.1, + cond_text_dropout=0.1, + cond_hdr_dropout=0.1, + prompt_type='output_prompt', + text_cfg=7.5, + img_cfg=1.2, + hdr_cfg=7.5, + hdr_rate=0.1, + ic_condition='bg', + hdr_train=False, + **kwargs + ): + super().__init__(*args, **kwargs) + self.hdr_train = hdr_train + if self.hdr_train: + self.hdr_encoder = HDRCtrlModel() + self.hdr_encoder = self.hdr_encoder.to(self.unet.device) + self.mlp = CombineMLP() + self.cond_hdr_dropout = cond_hdr_dropout + self.hdr_rate = hdr_rate + + self.cond_image_dropout = cond_image_dropout + self.cond_text_dropout = cond_text_dropout + + assert ic_condition in ['fg', 'bg'] + assert prompt_type in ['output_prompt', 'edit_prompt', 'mixed_prompt'] + self.prompt_type = prompt_type + self.ic_condition = ic_condition + + self.text_cfg = text_cfg + self.img_cfg = img_cfg + self.hdr_cfg = hdr_cfg + + #! 开启xformers训练设置 + # self.unet.enable_xformers_memory_efficient_attention() + # self.unet.enable_gradient_checkpointing() + + def encode_text(self, text): + with torch.cuda.amp.autocast(dtype=torch.float16): + encoded_text = super().encode_text(text) + return encoded_text + + def encode_image_to_latent(self, image): + # with torch.cuda.amp.autocast(dtype=torch.float16): + latent = super().encode_image_to_latent(image) + return latent + + # @torch.cuda.amp.autocast(dtype=torch.float16) + @torch.no_grad() + def get_prompt(self, batch, mode): + # if mode == 'train': + # if self.prompt_type == 'output_prompt': + # prompt = batch['output_prompt'] + # elif self.prompt_type == 'edit_prompt': # training的时候是edit prompt + # prompt = batch['edit_prompt'] + # elif self.prompt_type == 'mixed_prompt': + # if int(torch.rand(1)) > 0.5: + # prompt = batch['output_prompt'] + # else: + # prompt = batch['edit_prompt'] + # else: + # prompt = batch['output_prompt'] + prompt = batch['text_prompt'] + if not self.hdr_train: #! 如果hdr后续加进来text了, 还是需要? + if torch.rand(1).item() < self.cond_text_dropout: + prompt = 'change the background' + cond_text = self.encode_text(prompt) + if mode == 'train': + if torch.rand(1).item() < self.cond_text_dropout: + cond_text = torch.zeros_like(cond_text) + # import pdb; pdb.set_trace() + return cond_text + + # @torch.cuda.amp.autocast(dtype=torch.float16) + @torch.no_grad() + def encode_image_to_latent(self, image): + b, f, c, h, w = image.shape + image = rearrange(image, 'b f c h w -> (b f) c h w') + latent = super().encode_image_to_latent(image) + latent = rearrange(latent, '(b f) c h w -> b f c h w', b=b) + return latent + + # @torch.cuda.amp.autocast(dtype=torch.float16) + @torch.no_grad() + def decode_latent_to_image(self, latent): + b, f, c, h, w = latent.shape + latent = rearrange(latent, 'b f c h w -> (b f) c h w') + + image = [] + for latent_ in latent: + image_ = super().decode_latent_to_image(latent_[None]) + image.append(image_.sample) #! 注意一下这里 之前没报过错吗; -> 之前不是一个类 + image = torch.cat(image, dim=0) + # image = super().decode_latent_to_image(latent) + image = rearrange(image, '(b f) c h w -> b f c h w', b=b) + return image + + @torch.no_grad() + def get_cond_image(self, batch, mode): + # import pdb; pdb.set_trace() + cond_fg_image = batch['fg_video'] # 这边condition 就是 input_video了, 估计是concat或者ctrlnet + cond_fg_image = self.encode_image_to_latent(cond_fg_image) + if self.ic_condition == 'bg': + cond_bg_image = batch['bg_video'] + if torch.all(cond_bg_image == 0): + cond_bg_image = torch.zeros_like(cond_fg_image) #! 背景一定概率为0, 置为0.3 + else: + cond_bg_image = self.encode_image_to_latent(cond_bg_image) + cond_image = torch.cat((cond_fg_image, cond_bg_image), dim=2) #(1,16,8,64,64) + else: + cond_image = cond_fg_image + # test code: 可视化代码 + # from PIL import Image + # Image.fromarray(((batch['input_video'] + 1) / 2 * 255).byte()[0,0].permute(1,2,0).cpu().numpy()).save('img1.png') + + # ip2p does not scale cond image, so we unscale the cond image + # cond_image = self.encode_image_to_latent(cond_image) / self.scale_factor # 额 就是一个vae encode,没有缩放;这边不进行缩放吗? 啥意思呢 + + if mode == 'train': + # if int(torch.rand(1)) < self.cond_image_dropout: # 0.1的概率随机初始化, 应该是为了保障一个鲁棒性 难怪有的时候是全0, 不是代码的bug #! 艹 bug, 这么久才发现.... + if torch.rand(1).item() < self.cond_image_dropout: + cond_image = torch.zeros_like(cond_image) + return cond_image + + @torch.no_grad() + def get_diffused_image(self, batch, mode): + # import pdb; pdb.set_trace() + x = batch['tgt_video'] # 这边编辑的时候, 具体加噪和去噪的gt, 整个这套流程都是以编辑后, 即edited video作为输入 + # from PIL import Image + # Image.fromarray(((batch['edited_video'] + 1) / 2 * 255).byte()[0,0].permute(1,2,0).cpu().numpy()).save('img2.png') + b, *_ = x.shape + x = self.encode_image_to_latent(x) # (1, 16, 4, 32, 32), 经过了vae encode + eps = torch.randn_like(x) + + if mode == 'train': + t = torch.randint(0, self.num_timesteps, (b,), device=x.device).long() + else: + t = torch.full((b,), self.num_timesteps-1, device=x.device, dtype=torch.long) + x_t = self.add_noise(x, t, eps) # 加噪t步长 eps表示高斯噪声, 和scheduler的加噪 + + if self.prediction_type == 'epsilon': + return x_t, eps, t + else: + return x_t, x, t + + + @torch.no_grad() + def get_hdr_image(self, batch, mode): + x = batch['ldr_video'] # todo (16,3,256,512), float, tensor, device -> (1,16,3,256,256) 注意此时仅有ldr + # import pdb; pdb.set_trace() + hdr_latents, hdr_latents_1, hdr_latents_2 = self.hdr_encoder(x) + if mode == 'train': #! 考虑一下这个开不开, 因为后面要拉consistency loss + if torch.rand(1).item() < self.cond_hdr_dropout: + hdr_latents = torch.zeros_like(hdr_latents) + hdr_latents_1 = torch.zeros_like(hdr_latents_1) + hdr_latents_2 = torch.zeros_like(hdr_latents_2) + return hdr_latents, hdr_latents_1, hdr_latents_2 + + @torch.no_grad() # batch中需要加载mask + def get_mask(self, batch, mode, target): + # (1,16,1,512,512) + # import pdb; pdb.set_trace() + mask = batch['fg_mask'] # todo 返回mask (n,16,1,512,512) + bs = mask.shape[0] + target_height, target_width = target.shape[-2:] #(n,16,3,64,64) + + mask = rearrange(mask, 'b f c h w -> (b f) c h w') + resized_mask = F.interpolate(mask, size=(target_height, target_width), mode='bilinear', align_corners=False) + # resized_mask = resized_mask.unsqueeze(0) + resized_mask = rearrange(resized_mask, '(b f) c h w -> b f c h w', b=bs) + if target.shape[2] != resized_mask.shape[2]: + resized_mask = resized_mask.expand(-1, -1, target.shape[2], -1, -1) # 匹配目标通道数 + + return resized_mask + + @torch.no_grad() + def process_batch(self, batch, mode): #! 可视化这边的image, 查看问题出在哪了。。。 √, 应该是randn_drop的事 + # import pdb; pdb.set_trace() + cond_image = self.get_cond_image(batch, mode) # 把输入的src image进行一个编码, 这边只有vae的encode, 且没有乘缩放的系数(ip2p本身没乘...) + diffused_image, target, t = self.get_diffused_image(batch, mode) # diffused_image: 经过了vae encode, 和scheduler的加噪,标准的降噪输入 + # target: 这边是epsilon目标, 因此还是拉成epsilon的损失;t: 训练阶段是随机的一个数值, 推理阶段一般都是1000 + prompt = self.get_prompt(batch, mode) + model_kwargs = { + 'encoder_hidden_states': prompt + } + # import pdb; pdb.set_trace() + if self.hdr_train: + hdr_image, hdr_image_1, hdr_image_2 = self.get_hdr_image(batch, mode) #(16,3,768) + fg_mask = self.get_mask(batch, mode, target) # 把原图像前景mask resize到target大小 + + model_kwargs = { + 'encoder_hidden_states': {'hdr_latents': hdr_image, 'encoder_hidden_states': prompt, 'hdr_latents_1': hdr_image_1, 'hdr_latents_2': hdr_image_2, 'fg_mask': fg_mask} + } + + + return { + 'diffused_input': diffused_image, # (1, 16, 4, 64, 64), 经过了vae encode, 和scheduler的加噪 + 'condition': cond_image, # 把输入的src image进行一个编码, 这边只有vae的encode, 且没有乘缩放的系数 (1,16,8,64,64) + 'target': target, # 这个是加到tgt video的高斯噪声 + 't': t, # 0~1000的一个时刻 + 'model_kwargs': model_kwargs, # 这边就是一个text_hidden_states + } + + def training_step(self, batch, batch_idx): #! 注意一下仅仅训motion layer + # import pdb; pdb.set_trace() + processed_batch = self.process_batch(batch, mode='train') #(1,16,3,256,256), 读取的序列化图片, 仅仅做了一个归一化操作 + diffused_input = processed_batch['diffused_input'] # (1,16,4,64,64), edit images, 经过了vae encode, 和scheduler的加噪 + condition = processed_batch['condition'] # (1,16,8,64,64) 把输入的src images进行一个编码, 这边只有vae的encode, 且没有乘缩放的系数 + target = processed_batch['target'] # (1,16,4,64,64), target是加入的高斯噪声 + t = processed_batch['t'] # [257], 一个0~1000的随机时刻 + + model_kwargs = processed_batch['model_kwargs'] # dict, 仅包含一项: encoder_hidden_states, [1, 77, 768] text_hidden_states + + model_input = torch.cat([diffused_input, condition], dim=2) # b, f, c, h, w [1,16,8,32,32] 这边是做的concat, 很多edit文章经典操作, 把两个东西concat起来 + #! 半精度 + # model_input = model_input.float() + # model_kwargs['encoder_hidden_states'] = model_kwargs['encoder_hidden_states'].half() + model_input = rearrange(model_input, 'b f c h w -> b c f h w') # [1,8,16,32,32] + + pred = self.unet(model_input, t, **model_kwargs).sample # (1,4,16,64,64) #! + pred = rearrange(pred, 'b c f h w -> b f c h w') # (1,16,4,64,64) #! + + if not self.hdr_train: + loss = self.get_loss(pred, target, t) # 0.320 + else: + fg_mask = model_kwargs['encoder_hidden_states']['fg_mask'] + loss = self.get_hdr_loss(fg_mask, pred, target) + ### add consistency loss ### + # todo: 三个相同的model_input, 不同的model_kwargs (注意stack到一起, attn里面的逻辑也得改...) + # if self.hdr_train: + # fg_mask = model_kwargs['encoder_hidden_states']['fg_mask'] + # hdr_latents = model_kwargs['encoder_hidden_states']['hdr_latents_1'] + # hdr_latents_1 = model_kwargs['encoder_hidden_states']['hdr_latents_1'] + # hdr_latents_2 = model_kwargs['encoder_hidden_states']['hdr_latents_2'] + + # model_input = torch.cat([diffused_input, condition], dim=2) + # model_input = rearrange(model_input, 'b f c h w -> b c f h w') + # model_input_1 = model_input.clone() + # model_input_2 = model_input.clone() + # model_input_all = torch.cat([model_input, model_input_1, model_input_2], dim=0) + + # prompt = model_kwargs['encoder_hidden_states']['encoder_hidden_states'] #(1*n,77,768) + # prompt_all = torch.cat([prompt, prompt, prompt], dim=0) #(3*n,77,768) + # # import pdb; pdb.set_trace() + # model_kwargs['encoder_hidden_states']['encoder_hidden_states'] = prompt_all + + # # import pdb; pdb.set_trace() + # hdr_latents_all = torch.cat([hdr_latents, hdr_latents_1, hdr_latents_2], dim=0) #(3*n,16,77,768) + # model_kwargs['encoder_hidden_states']['hdr_latents']=hdr_latents_all + # pred_all = self.unet(model_input_all, t, **model_kwargs).sample # (1,4,16,64,64) + # pred_all = rearrange(pred_all, 'b c f h w -> b f c h w') + + # pred, pred1, pred2 = pred_all.chunk(3, dim=0) + # loss_ori = self.get_hdr_loss(fg_mask, pred, target) + + # # 假设获得了L1, L2 + # # hdr_latents_1 = mask(hdr_latents) # 随机构造一个mask + 逻辑矫正 + # # model_kwargs['encoder_hidden_states']['hdr_latents']=hdr_latents_1 + # # pred1 = self.unet(model_input, t, **model_kwargs).sample # get L1下的预测值 (1,16,4,64,64) + # # pred1 = rearrange(pred1, 'b c f h w -> b f c h w') + + # # model_input = torch.cat([diffused_input, condition], dim=2) + # # model_input = rearrange(model_input, 'b f c h w -> b c f h w') + # # # hdr_latents_2 = 1-mask(hdr_latents) + # # model_kwargs['encoder_hidden_states']['hdr_latents']=hdr_latents_2 + # # pred2 = self.unet(model_input, t, **model_kwargs).sample # get L2下的预测值 + # # pred2 = rearrange(pred2, 'b c f h w -> b f c h w') + # # import pdb; pdb.set_trace() + # pred_combine = self.mlp(pred1, pred2) #! todo: 构造mlp loss 错了!! 搞对一下, 应该需要展平.... + # loss_c = self.get_hdr_loss(fg_mask, pred, pred_combine) + # # loss_c = MSELoss(mask*pred, mask*pred_conbine) # todo: change to函数, 逻辑矫正 + + # loss = loss_ori + self.hdr_rate * loss_c # 设一个系数, 好控制变化 + ### end ### + self.log('train_loss', loss, sync_dist=True) + + latent_pred = self.predict_x_0_from_x_t(pred, t, diffused_input) # (1,16,4,32,32) + image_pred = self.decode_latent_to_image(latent_pred) # 这边相当于是pred_x0了, (1,16,3,256,256) + drop_out = torch.all(condition == 0).item() + + res_dict = { + 'loss': loss, + 'pred': image_pred, + 'drop_out': drop_out, + 'time': t[0].item() + } + return res_dict + + @torch.no_grad() + @torch.cuda.amp.autocast(dtype=torch.bfloat16) + def validation_step(self, batch, batch_idx): # 没写好 可以先pass + # pass + # import pdb; pdb.set_trace() + if not self.hdr_train: + from .inference.inference import InferenceIP2PVideo + inf_pipe = InferenceIP2PVideo( + self.unet, + beta_start=self.scheduler.config.beta_start, + beta_end=self.scheduler.config.beta_end, + beta_schedule=self.scheduler.config.beta_schedule, + num_ddim_steps=20 + ) + # import pdb; pdb.set_trace() + processed_batch = self.process_batch(batch, mode='val') + diffused_input = torch.randn_like(processed_batch['diffused_input']) #(1,16,4,64,64) + + condition = processed_batch['condition'] # 这边其实留有一个接口给condition (1,16,8,64,64) + img_cond = condition + text_cond = processed_batch['model_kwargs']['encoder_hidden_states'] + # import pdb; pdb.set_trace() + res = inf_pipe( + latent = diffused_input, + text_cond = text_cond, + text_uncond = self.encode_text(['']), + img_cond = img_cond, + text_cfg = self.text_cfg, + img_cfg = self.img_cfg, + hdr_cfg = self.hdr_cfg + ) + + latent_pred = res['latent'] + image_pred = self.decode_latent_to_image(latent_pred) + res_dict = { + 'pred': image_pred, + } + else: + from .inference.inference import InferenceIP2PVideoHDR + inf_pipe = InferenceIP2PVideoHDR( + self.unet, + beta_start=self.scheduler.config.beta_start, + beta_end=self.scheduler.config.beta_end, + beta_schedule=self.scheduler.config.beta_schedule, + num_ddim_steps=20 + ) + # import pdb; pdb.set_trace() + processed_batch = self.process_batch(batch, mode='val') + diffused_input = torch.randn_like(processed_batch['diffused_input']) #(1,16,4,64,64) + + condition = processed_batch['condition'] # 这边其实留有一个接口给condition (1,16,8,64,64) + model_kwargs = processed_batch['model_kwargs'] + img_cond = condition + text_cond = model_kwargs['encoder_hidden_states']['encoder_hidden_states'] + hdr_cond = model_kwargs['encoder_hidden_states']['hdr_latents'] + + # import pdb; pdb.set_trace() + res = inf_pipe( + latent = diffused_input, + text_cond = text_cond, + text_uncond = self.encode_text(['']), + hdr_cond = hdr_cond, + img_cond = img_cond, + text_cfg = self.text_cfg, + img_cfg = self.img_cfg, + ) + + latent_pred = res['latent'] + image_pred = self.decode_latent_to_image(latent_pred) + res_dict = { + 'pred': image_pred, + } + return res_dict + + def configure_optimizers(self): + # optimizer = torch.optim.AdamW(self.unet.parameters(), lr=self.optim_args['lr']) + import bitsandbytes as bnb + params = [] + for name, p in self.unet.named_parameters(): + if ('transformer_in' in name) or ('temp_' in name): + # p.requires_grad = True + params.append(p) + else: + pass + # p.requires_grad = False + optimizer = bnb.optim.Adam8bit(params, lr=self.optim_args['lr'], betas=(0.9, 0.999)) + return optimizer + + def initialize_unet(self, unet_init_weights): + if unet_init_weights is not None: + print(f'INFO: initialize denoising UNet from {unet_init_weights}') + sd = torch.load(unet_init_weights, map_location='cpu') + model_sd = self.unet.state_dict() + # fit input conv size + for k in model_sd.keys(): + if k in sd.keys(): + pass + else: + # handling temporal layers + if (('temp_' in k) or ('transformer_in' in k)) and 'proj_out' in k: + # print(f'INFO: initialize {k} from {model_sd[k].shape} to zeros') + sd[k] = torch.zeros_like(model_sd[k]) + else: + # print(f'INFO: initialize {k} from {model_sd[k].shape} to random') + sd[k] = model_sd[k] + self.unet.load_state_dict(sd) + +class InstructP2PVideoTrainerTemporal(InstructP2PVideoTrainer): + def initialize_unet(self, unet_init_weights): # 这边对比上一级来说, 新加的部分在于 rewrite了unet的load函数 + if unet_init_weights is not None: + print(f'INFO: initialize denoising UNet from {unet_init_weights}') + sd_init_weights, motion_module_init_weights, iclight_init_weights = unet_init_weights + os.makedirs(sd_init_weights, exist_ok=True) + sd_init_weights, motion_module_init_weights, iclight_init_weights = f'models/{sd_init_weights}', f'models/{motion_module_init_weights}', f'models/{iclight_init_weights}' + + if not os.path.exists(sd_init_weights): + url = 'https://huggingface.co/stablediffusionapi/realistic-vision-v51/resolve/main/unet/diffusion_pytorch_model.safetensors' + download_url_to_file(url=url, dst=sd_init_weights) + if not os.path.exists(motion_module_init_weights): + url = 'https://huggingface.co/aleafy/RelightVid/resolve/main/relvid_mm_sd15_fbc.pth' + download_url_to_file(url=url, dst=motion_module_init_weights) + if not os.path.exists(iclight_init_weights): + url = 'https://huggingface.co/lllyasviel/ic-light/resolve/main/iclight_sd15_fbc.safetensors' + download_url_to_file(url=url, dst=iclight_init_weights) + + sd = load_file(sd_init_weights) #! 关于加载iclight的unet, 后面再加到yaml里面... 我甚至觉得只要改unet, vae和text其实都差不多 + + # sd = torch.load(sd_init_weights, map_location='cpu') # 注意debug看看这是啥 + 打印一下原有和加载的keys + if self.unet.use_motion_module: + motion_sd = torch.load(motion_module_init_weights, map_location='cpu') + assert len(sd) + len(motion_sd) == len(self.unet.state_dict()), f'Improper state dict length, got {len(sd) + len(motion_sd)} expected {len(self.unet.state_dict())}' #! 注意一下这行保证了加载的key至少在数量上是对应的; 这行的目的是self.unet是自己定义的 而这两个加载的是别的地方训练的(可能是diffusers中的) + sd.update(motion_sd) + + for k, v in self.unet.state_dict().items(): + if 'pos_encoder.pe' in k: # 这边是原来iv2v的代码 temporal_position_encoding_max_len, 设置为 32 + sd[k] = v # the size of pe may change, 主要是temporal layer的size会发生改变... √ 由于输入的max_len变了 + # if 'conv_in.weight' in k: #! tmp, 这里是test一下 + # sd[k] = v + else: + assert len(sd) == len(self.unet.state_dict()) + + self.unet.load_state_dict(sd) # 为什么这里可以完美适配? √ + # todo: 更改sd的conv_in.weight的shape到12; 更改函数forward, 支持多个输入cond; iclight的sd_offset加载进去; + unet = self.unet # saVe一下 + # 这里是更改conv_in的shape; #! 这边注意一下要改成3D版本的unet + with torch.no_grad(): + # new_conv_in = torch.nn.Conv2d(12, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding) + new_conv_in = InflatedConv3d(12, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding) + new_conv_in.weight.zero_() + new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight) + new_conv_in.bias = unet.conv_in.bias + unet.conv_in = new_conv_in + + ###### -- 更改 forward函数 --- ##### + + # 这里是更改forward函数。 具体调用的部分在main后面,那里也得改 + # unet_original_forward = unet.forward + # def hooked_unet_forward(sample, timestep, encoder_hidden_states, **kwargs): + # c_concat = kwargs['cross_attention_kwargs']['concat_conds'].to(sample) # (1,8,67,120) + # c_concat = torch.cat([c_concat] * (sample.shape[0] // c_concat.shape[0]), dim=0) # (2,8,67,120) 应该是复制一份,用于cfg + # new_sample = torch.cat([sample, c_concat], dim=2) #(2,12,67,120) 这边还是在通道维度上进行的concat #! change 在第二维cat (2,1,12,67,120) + # # todo 这边中间可以加一个f的通道 b,c,f,h,w ; 另一种方式: 对于数据进行改变, 那么上述concat的代码也需要进行变换了... + # # new_sample = new_sample.unsqueeze(2) # (2,12,1,67,120) #! 这里需要change, 要在一输入之前就要更改他的维度, 因此前面concat也需要稍微改一下 不要在forward中增加f维度 (因为要依赖输入) + + # new_sample = rearrange(new_sample, 'b f c h w -> b c f h w') + # kwargs['cross_attention_kwargs'] = {} + # # return unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs) + # result = unet_original_forward(new_sample, timestep, encoder_hidden_states, **kwargs) + # # return (result[0].squeeze(2),) #! tmp + # return (rearrange(result[0], 'b c f h w -> b f c h w'),) + # unet.forward = hooked_unet_forward + + ##### -- 更改 forward函数 --- ##### + + # model_path = '/home/fy/Code/instruct-video-to-video/IC-Light/models/iclight_sd15_fbc.safetensors' + # 这里是加载iclight的lora weight + sd_offset = load_file(iclight_init_weights) + sd_origin = unet.state_dict() + keys = sd_origin.keys() + for k in sd_offset.keys(): + sd_origin[k] = sd_origin[k] + sd_offset[k] + # sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()} + self.unet.load_state_dict(sd_origin, strict=True) + del sd_offset, sd_origin, unet, keys + + # print(1) + # todo 试写一下iclight unet的加载方式 + # sd = load_file('/home/fy/Code/IC-Light/cache_models/models--stablediffusionapi--realistic-vision-v51/snapshots/19e3643d7d963c156d01537188ec08f0b79a514a/unet/diffusion_pytorch_model.safetensors') + + # debug: print参数 + # with open('logs/sd_keys.txt', 'w') as f: + # f.write("SD Keys:\n") + # for key in sd_ori.keys(): + # f.write(f"{key}\n") + + # unet_state_dict = self.unet.state_dict() + # with open('logs/unet_state_dict_keys.txt', 'w') as f: + # f.write("UNet State Dict Keys:\n") + # for key in unet_state_dict.keys(): + # f.write(f"{key}\n") + else: + with torch.no_grad(): + new_conv_in = InflatedConv3d(12, self.unet.conv_in.out_channels, self.unet.conv_in.kernel_size, self.unet.conv_in.stride, self.unet.conv_in.padding) + self.unet.conv_in = new_conv_in + + def configure_optimizers(self): # 决定了仅仅训练motion_module的参数 注意一下pl.Trainer独有的函数 + import bitsandbytes as bnb + motion_params = [] + remaining_params = [] + train_names = [] # for debug + for name, p in self.unet.named_parameters(): + if ('motion' in name): #! 哦哦 这里决定了哪些参数用于训练... 这里实际训练的只有motion相关参数 + motion_params.append(p) + train_names.append(name) + elif ('attentions' in name): + motion_params.append(p) + train_names.append(name) + else: + remaining_params.append(p) + # import pdb; pdb.set_trace() + optimizer = bnb.optim.Adam8bit([ + {'params': motion_params, 'lr': self.optim_args['lr']}, + ], betas=(0.9, 0.999)) + return optimizer + + +class InstructP2PVideoTrainerTemporalText(InstructP2PVideoTrainerTemporal): + def initialize_unet(self, unet_init_weights): # 这边对比上一级来说, 新加的部分在于 rewrite了unet的load函数 + if unet_init_weights is not None: + print(f'INFO: initialize denoising UNet from {unet_init_weights}') + sd_init_weights, motion_module_init_weights, iclight_init_weights = unet_init_weights + if self.base_path: + sd_init_weights = f"{self.base_path}/{sd_init_weights}" + if '.safetensors' in sd_init_weights: # .safetensors的加载方式 + sd = load_file(sd_init_weights) #! 关于加载iclight的unet, 后面再加到yaml里面... 我甚至觉得只要改unet, vae和text其实都差不多 + else: #'.ckpt'场景 + sd = torch.load(sd_init_weights, map_location='cpu') + + # sd = torch.load(sd_init_weights, map_location='cpu') # 注意debug看看这是啥 + 打印一下原有和加载的keys + if self.unet.use_motion_module: + motion_sd = torch.load(motion_module_init_weights, map_location='cpu') + assert len(sd) + len(motion_sd) == len(self.unet.state_dict()), f'Improper state dict length, got {len(sd) + len(motion_sd)} expected {len(self.unet.state_dict())}' #! 注意一下这行保证了加载的key至少在数量上是对应的; 这行的目的是self.unet是自己定义的 而这两个加载的是别的地方训练的(可能是diffusers中的) + sd.update(motion_sd) + + for k, v in self.unet.state_dict().items(): + if 'pos_encoder.pe' in k: # 这边是原来iv2v的代码 temporal_position_encoding_max_len, 设置为 32 + sd[k] = v # the size of pe may change, 主要是temporal layer的size会发生改变... √ 由于输入的max_len变了 + # if 'conv_in.weight' in k: #! tmp, 这里是test一下 + # sd[k] = v + else: + assert len(sd) == len(self.unet.state_dict()) + + self.unet.load_state_dict(sd) # 为什么这里可以完美适配? √ + # todo: 更改sd的conv_in.weight的shape到12; 更改函数forward, 支持多个输入cond; iclight的sd_offset加载进去; + unet = self.unet # saVe一下 + # 这里是更改conv_in的shape; #! 这边注意一下要改成3D版本的unet + with torch.no_grad(): + # new_conv_in = torch.nn.Conv2d(12, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding) + new_conv_in = InflatedConv3d(8, unet.conv_in.out_channels, unet.conv_in.kernel_size, unet.conv_in.stride, unet.conv_in.padding) + new_conv_in.weight.zero_() + new_conv_in.weight[:, :4, :, :].copy_(unet.conv_in.weight) + new_conv_in.bias = unet.conv_in.bias + unet.conv_in = new_conv_in + + + # model_path = '/home/fy/Code/instruct-video-to-video/IC-Light/models/iclight_sd15_fbc.safetensors' + # 这里是加载iclight的lora weight + sd_offset = load_file(iclight_init_weights) + sd_origin = unet.state_dict() + keys = sd_origin.keys() + for k in sd_offset.keys(): + sd_origin[k] = sd_origin[k] + sd_offset[k] + # sd_merged = {k: sd_origin[k] + sd_offset[k] for k in sd_origin.keys()} + self.unet.load_state_dict(sd_origin, strict=True) + del sd_offset, sd_origin, unet, keys + + else: + with torch.no_grad(): + new_conv_in = InflatedConv3d(8, self.unet.conv_in.out_channels, self.unet.conv_in.kernel_size, self.unet.conv_in.stride, self.unet.conv_in.padding) + self.unet.conv_in = new_conv_in diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..7fc64015cddb88fac7f5cf402303d34a42a8fbb1 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,23 @@ +torch==2.1.2 --index-url https://download.pytorch.org/whl/cu118 +torchvision==0.16.2 --index-url https://download.pytorch.org/whl/cu118 +diffusers==0.30.1 +transformers==4.44.2 +opencv-python +safetensors +pillow==10.2.0 +imageio +einops +peft +gradio==3.50.2 +protobuf==3.20 +accelerate +wandb +pytorch-lightning +opencv-contrib-python +omegaconf +open-clip-torch +jsonlines +diff-match-patch +git+https://github.com/openai/CLIP.git +deepspeed +tqdm diff --git a/static_fg_sync_bg_visualization_fy/14_22_100fps.mp4 b/static_fg_sync_bg_visualization_fy/14_22_100fps.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..bd6a6e0abdc01a518a9be0183cbfbed8d15faa4d Binary files /dev/null and b/static_fg_sync_bg_visualization_fy/14_22_100fps.mp4 differ diff --git a/static_fg_sync_bg_visualization_fy/14_55_100fps.mp4 b/static_fg_sync_bg_visualization_fy/14_55_100fps.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..a604fa1259374523768e4f87ed054412e7110e55 Binary files /dev/null and b/static_fg_sync_bg_visualization_fy/14_55_100fps.mp4 differ diff --git a/static_fg_sync_bg_visualization_fy/15_27_100fps.mp4 b/static_fg_sync_bg_visualization_fy/15_27_100fps.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..0ca6af6c2b4c5f07c7eee2beca3dec1ff0bd2faf Binary files /dev/null and b/static_fg_sync_bg_visualization_fy/15_27_100fps.mp4 differ diff --git a/static_fg_sync_bg_visualization_fy/18_23_100fps.mp4 b/static_fg_sync_bg_visualization_fy/18_23_100fps.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..3892fc6111175750400c33d12b8c38da559eeb40 Binary files /dev/null and b/static_fg_sync_bg_visualization_fy/18_23_100fps.mp4 differ diff --git a/static_fg_sync_bg_visualization_fy/18_33_100fps.mp4 b/static_fg_sync_bg_visualization_fy/18_33_100fps.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..6e72bff8b265a6326fcf8066a82491d6ae2aa5e6 Binary files /dev/null and b/static_fg_sync_bg_visualization_fy/18_33_100fps.mp4 differ diff --git a/static_fg_sync_bg_visualization_fy/22_39_100fps.mp4 b/static_fg_sync_bg_visualization_fy/22_39_100fps.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..ea7bb24255740fa6716ab33ad794dc29d8cbec69 Binary files /dev/null and b/static_fg_sync_bg_visualization_fy/22_39_100fps.mp4 differ diff --git a/static_fg_sync_bg_visualization_fy/22_59_100fps.mp4 b/static_fg_sync_bg_visualization_fy/22_59_100fps.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..7bc2ecb84835735e3206de1d0515503f06374fb0 Binary files /dev/null and b/static_fg_sync_bg_visualization_fy/22_59_100fps.mp4 differ diff --git a/static_fg_sync_bg_visualization_fy/9_10_100fps.mp4 b/static_fg_sync_bg_visualization_fy/9_10_100fps.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..13fb8be990f57b76fb47b217f7785709e7f27abf Binary files /dev/null and b/static_fg_sync_bg_visualization_fy/9_10_100fps.mp4 differ diff --git a/static_fg_sync_bg_visualization_fy/9_14_100fps.mp4 b/static_fg_sync_bg_visualization_fy/9_14_100fps.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..119f75d83647474ce82841d90f8f90e68d86d600 Binary files /dev/null and b/static_fg_sync_bg_visualization_fy/9_14_100fps.mp4 differ diff --git a/static_fg_sync_bg_visualization_fy/9_8_100fps.mp4 b/static_fg_sync_bg_visualization_fy/9_8_100fps.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..27fcc83ffbb8fa1ef3ddd8d8b59311ad731276d3 Binary files /dev/null and b/static_fg_sync_bg_visualization_fy/9_8_100fps.mp4 differ diff --git a/static_fg_sync_bg_visualization_fy/9_9_100fps.mp4 b/static_fg_sync_bg_visualization_fy/9_9_100fps.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..bb301e82162db1be18b8eca56101740b48eaf7b5 Binary files /dev/null and b/static_fg_sync_bg_visualization_fy/9_9_100fps.mp4 differ