pandaphd committed on
Commit 1ae4e5b · 1 Parent(s): e2515d4

Removed <file> from Git LFS tracking

.gitattributes CHANGED
@@ -1,3 +1,36 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:f978134ea372378fb27d2c9aaeb7db0a8d814207997bdad9ed8f368783d0a857
3
- size 1593
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ * !text !filter !merge !diff
README.md CHANGED
@@ -1,3 +1,14 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:9c83176a74a800ceebc4069a48b824b4c1a7b2f06d02ff5959e63eebc2a8d222
3
- size 331
1
+ ---
2
+ title: Generative Photography
3
+ emoji: 📈
4
+ colorFrom: blue
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 5.20.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: cc-by-nc-nd-4.0
11
+ short_description: Demo for Generative Photography
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,3 +1,152 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:14e58bab9ed2b6eac8619e2b9c3c3ff03bf4689406c28de8eb49237f6f25c23b
3
- size 8306
1
+ import gradio as gr
2
+ import json
3
+ import torch
4
+ from inference_bokehK import load_models as load_bokeh_models, run_inference as run_bokeh_inference, OmegaConf
5
+ from inference_focal_length import load_models as load_focal_models, run_inference as run_focal_inference
6
+ from inference_shutter_speed import load_models as load_shutter_models, run_inference as run_shutter_inference
7
+ from inference_color_temperature import load_models as load_color_models, run_inference as run_color_inference
8
+
9
+ torch.manual_seed(42)
10
+
11
+ bokeh_cfg = OmegaConf.load("configs/inference_genphoto/adv3_256_384_genphoto_relora_bokehK.yaml")
12
+ bokeh_pipeline, bokeh_device = load_bokeh_models(bokeh_cfg)
13
+
14
+ focal_cfg = OmegaConf.load("configs/inference_genphoto/adv3_256_384_genphoto_relora_focal_length.yaml")
15
+ focal_pipeline, focal_device = load_focal_models(focal_cfg)
16
+
17
+ shutter_cfg = OmegaConf.load("configs/inference_genphoto/adv3_256_384_genphoto_relora_shutter_speed.yaml")
18
+ shutter_pipeline, shutter_device = load_shutter_models(shutter_cfg)
19
+
20
+ color_cfg = OmegaConf.load("configs/inference_genphoto/adv3_256_384_genphoto_relora_color_temperature.yaml")
21
+ color_pipeline, color_device = load_color_models(color_cfg)
22
+
23
+
24
+ def generate_bokeh_video(base_scene, bokehK_list):
25
+ try:
26
+ torch.manual_seed(42)
27
+ if len(json.loads(bokehK_list)) != 5:
28
+ raise ValueError("Exactly 5 Bokeh K values required")
29
+ return run_bokeh_inference(
30
+ pipeline=bokeh_pipeline, tokenizer=bokeh_pipeline.tokenizer,
31
+ text_encoder=bokeh_pipeline.text_encoder, base_scene=base_scene,
32
+ bokehK_list=bokehK_list, device=bokeh_device
33
+ )
34
+ except Exception as e:
35
+ return f"Error: {str(e)}"
36
+
37
+ def generate_focal_video(base_scene, focal_length_list):
38
+ try:
39
+ torch.manual_seed(42)
40
+ if len(json.loads(focal_length_list)) != 5:
41
+ raise ValueError("Exactly 5 focal length values required")
42
+ return run_focal_inference(
43
+ pipeline=focal_pipeline, tokenizer=focal_pipeline.tokenizer,
44
+ text_encoder=focal_pipeline.text_encoder, base_scene=base_scene,
45
+ focal_length_list=focal_length_list, device=focal_device
46
+ )
47
+ except Exception as e:
48
+ return f"Error: {str(e)}"
49
+
50
+ def generate_shutter_video(base_scene, shutter_speed_list):
51
+ try:
52
+ torch.manual_seed(42)
53
+ if len(json.loads(shutter_speed_list)) != 5:
54
+ raise ValueError("Exactly 5 shutter speed values required")
55
+ return run_shutter_inference(
56
+ pipeline=shutter_pipeline, tokenizer=shutter_pipeline.tokenizer,
57
+ text_encoder=shutter_pipeline.text_encoder, base_scene=base_scene,
58
+ shutter_speed_list=shutter_speed_list, device=shutter_device
59
+ )
60
+ except Exception as e:
61
+ return f"Error: {str(e)}"
62
+
63
+ def generate_color_video(base_scene, color_temperature_list):
64
+ try:
65
+ torch.manual_seed(42)
66
+ if len(json.loads(color_temperature_list)) != 5:
67
+ raise ValueError("Exactly 5 color temperature values required")
68
+ return run_color_inference(
69
+ pipeline=color_pipeline, tokenizer=color_pipeline.tokenizer,
70
+ text_encoder=color_pipeline.text_encoder, base_scene=base_scene,
71
+ color_temperature_list=color_temperature_list, device=color_device
72
+ )
73
+ except Exception as e:
74
+ return f"Error: {str(e)}"
75
+
76
+
77
+
78
+ bokeh_examples = [
79
+ ["A variety of potted plants are displayed on a window sill, with some of them placed in yellow and white cups. The plants are arranged in different sizes and shapes, creating a visually appealing display.", "[18.0, 14.0, 10.0, 6.0, 2.0]"],
80
+ ["A colorful backpack with a floral pattern is sitting on a table next to a computer monitor.", "[2.3, 5.8, 10.2, 14.8, 24.9]"]
81
+ ]
82
+
83
+ focal_examples = [
84
+ ["A small office cubicle with a desk.", "[25.1, 36.1, 47.1, 58.1, 69.1]"],
85
+ ["A large white couch in a living room.", "[55.0, 46.0, 37.0, 28.0, 25.0]"]
86
+ ]
87
+
88
+ shutter_examples = [
89
+ ["A brown and orange leather handbag.", "[0.11, 0.22, 0.33, 0.44, 0.55]"],
90
+ ["A variety of potted plants.", "[0.2, 0.49, 0.69, 0.75, 0.89]"]
91
+ ]
92
+
93
+ color_examples = [
94
+ ["A blue sky with mountains.", "[5455.0, 5155.0, 5555.0, 6555.0, 7555.0]"],
95
+ ["A red couch in front of a window.", "[3500.0, 5500.0, 6500.0, 7500.0, 8500.0]"]
96
+ ]
97
+
98
+
99
+ with gr.Blocks(title="Generative Photography") as demo:
100
+ gr.Markdown("# **Generative Photography: Scene-Consistent Camera Control for Realistic Text-to-Image Synthesis** ")
101
+
102
+ with gr.Tabs():
103
+ with gr.Tab("BokehK Effect"):
104
+ gr.Markdown("### Generate Frames with Bokeh Blur Effect")
105
+ with gr.Row():
106
+ with gr.Column():
107
+ scene_input_bokeh = gr.Textbox(label="Scene Description", placeholder="Describe the scene you want to generate...")
108
+ bokeh_input = gr.Textbox(label="Bokeh Blur Values", placeholder="Enter 5 comma-separated values from 1-30, e.g., [2.44, 8.3, 10.1, 17.2, 24.0]")
109
+ submit_bokeh = gr.Button("Generate Video")
110
+ with gr.Column():
111
+ video_output_bokeh = gr.Video(label="Generated Video")
112
+ gr.Examples(bokeh_examples, [scene_input_bokeh, bokeh_input], [video_output_bokeh], generate_bokeh_video)
113
+ submit_bokeh.click(generate_bokeh_video, [scene_input_bokeh, bokeh_input], [video_output_bokeh])
114
+
115
+ with gr.Tab("Focal Length Effect"):
116
+ gr.Markdown("### Generate Frames with Focal Length Effect")
117
+ with gr.Row():
118
+ with gr.Column():
119
+ scene_input_focal = gr.Textbox(label="Scene Description", placeholder="Describe the scene you want to generate...")
120
+ focal_input = gr.Textbox(label="Focal Length Values", placeholder="Enter 5 comma-separated values from 24-70, e.g., [25.1, 30.2, 33.3, 40.8, 54.0]")
121
+ submit_focal = gr.Button("Generate Video")
122
+ with gr.Column():
123
+ video_output_focal = gr.Video(label="Generated Video")
124
+ gr.Examples(focal_examples, [scene_input_focal, focal_input], [video_output_focal], generate_focal_video)
125
+ submit_focal.click(generate_focal_video, [scene_input_focal, focal_input], [video_output_focal])
126
+
127
+ with gr.Tab("Shutter Speed Effect"):
128
+ gr.Markdown("### Generate Frames with Shutter Speed Effect")
129
+ with gr.Row():
130
+ with gr.Column():
131
+ scene_input_shutter = gr.Textbox(label="Scene Description", placeholder="Describe the scene you want to generate...")
132
+ shutter_input = gr.Textbox(label="Shutter Speed Values", placeholder="Enter 5 comma-separated values from 0.1-1.0, e.g., [0.15, 0.32, 0.53, 0.62, 0.82]")
133
+ submit_shutter = gr.Button("Generate Video")
134
+ with gr.Column():
135
+ video_output_shutter = gr.Video(label="Generated Video")
136
+ gr.Examples(shutter_examples, [scene_input_shutter, shutter_input], [video_output_shutter], generate_shutter_video)
137
+ submit_shutter.click(generate_shutter_video, [scene_input_shutter, shutter_input], [video_output_shutter])
138
+
139
+ with gr.Tab("Color Temperature Effect"):
140
+ gr.Markdown("### Generate Frames with Color Temperature Effect")
141
+ with gr.Row():
142
+ with gr.Column():
143
+ scene_input_color = gr.Textbox(label="Scene Description", placeholder="Describe the scene you want to generate...")
144
+ color_input = gr.Textbox(label="Color Temperature Values", placeholder="Enter 5 comma-separated values from 2000-10000, e.g., [3001.3, 4000.2, 4400.34, 5488.23, 8888.82]")
145
+ submit_color = gr.Button("Generate Video")
146
+ with gr.Column():
147
+ video_output_color = gr.Video(label="Generated Video")
148
+ gr.Examples(color_examples, [scene_input_color, color_input], [video_output_color], generate_color_video)
149
+ submit_color.click(generate_color_video, [scene_input_color, color_input], [video_output_color])
150
+
151
+ if __name__ == "__main__":
152
+ demo.launch(share=True)
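
A minimal sketch of driving one of the handlers above without the UI (hypothetical usage; importing app loads all four pipelines at import time, so the configs and weights referenced in this repo must be available locally):

    import json
    from app import generate_bokeh_video

    # Exactly five bokeh K values are required, passed as a JSON-encoded list.
    result = generate_bokeh_video(
        "A colorful backpack on a table next to a computer monitor.",
        json.dumps([2.3, 5.8, 10.2, 14.8, 24.9]),
    )
    print(result)  # the value fed to the gr.Video output, or an "Error: ..." string on failure
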
configs/inference_genphoto/adv3_256_384_genphoto_relora_bokehK.yaml CHANGED
@@ -1,3 +1,66 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:a74bacc98940eb895b1ac635f5e8b4fabb811d98c8a067ece44c0ac4ff460842
3
- size 1823
1
+ output_dir: "inference_output/genphoto_bokehK"
2
+
3
+ pretrained_model_repo: "pandaphd/generative_photography"
4
+ pretrained_model_path: "stable-diffusion-v1-5"
5
+
6
+ unet_subfolder: "unet_merged"
7
+
8
+ camera_adaptor_ckpt: "weights/checkpoint-bokehK.ckpt"
9
+ lora_ckpt: "weights/RealEstate10K_LoRA.ckpt"
10
+ motion_module_ckpt: "weights/v3_sd15_mm.ckpt"
11
+
12
+ lora_rank: 2
13
+ lora_scale: 1.0
14
+ motion_lora_rank: 0
15
+ motion_lora_scale: 1.0
16
+
17
+ unet_additional_kwargs:
18
+ use_motion_module : true
19
+ motion_module_resolutions : [ 1,2,4,8 ]
20
+ unet_use_cross_frame_attention : false
21
+ unet_use_temporal_attention : false
22
+ motion_module_mid_block: false
23
+ motion_module_type: Vanilla
24
+ motion_module_kwargs:
25
+ num_attention_heads : 8
26
+ num_transformer_block : 1
27
+ attention_block_types : [ "Temporal_Self", "Temporal_Self" ]
28
+ temporal_position_encoding : true
29
+ temporal_position_encoding_max_len : 32
30
+ temporal_attention_dim_div : 1
31
+ zero_initialize : false
32
+
33
+ camera_encoder_kwargs:
34
+ downscale_factor: 8
35
+ channels: [320, 640, 1280, 1280]
36
+ nums_rb: 2
37
+ cin: 384
38
+ ksize: 1
39
+ sk: true
40
+ use_conv: false
41
+ compression_factor: 1
42
+ temporal_attention_nhead: 8
43
+ attention_block_types: ["Temporal_Self", ]
44
+ temporal_position_encoding: true
45
+ temporal_position_encoding_max_len: 16
46
+
47
+ attention_processor_kwargs:
48
+ add_spatial: false
49
+ spatial_attn_names: 'attn1'
50
+ add_temporal: true
51
+ temporal_attn_names: '0'
52
+ camera_feature_dimensions: [320, 640, 1280, 1280]
53
+ query_condition: true
54
+ key_value_condition: true
55
+ scale: 1.0
56
+
57
+ noise_scheduler_kwargs:
58
+ num_train_timesteps: 1000
59
+ beta_start: 0.00085
60
+ beta_end: 0.012
61
+ beta_schedule: "linear"
62
+ steps_offset: 1
63
+ clip_sample: false
64
+
65
+ num_workers: 8
66
+ global_seed: 42
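
The config above is read via OmegaConf (as in app.py); a minimal sketch, assuming the path is resolved from the repo root, with attribute access mirroring the YAML keys:

    from omegaconf import OmegaConf

    cfg = OmegaConf.load("configs/inference_genphoto/adv3_256_384_genphoto_relora_bokehK.yaml")
    print(cfg.unet_subfolder)                        # "unet_merged"
    print(cfg.noise_scheduler_kwargs.beta_schedule)  # "linear"
    # Nested blocks such as noise_scheduler_kwargs are typically splatted into a
    # constructor, e.g. SomeScheduler(**cfg.noise_scheduler_kwargs) (hypothetical usage).
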
configs/inference_genphoto/adv3_256_384_genphoto_relora_color_temperature.yaml CHANGED
@@ -1,3 +1,66 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:d6f6e2911a8e440f4796db8ae67b919659067b859bacd7575953da6c2b8bfb2d
3
- size 1845
1
+ output_dir: "inference_output/genphoto_color_temperature"
2
+
3
+ pretrained_model_repo: "pandaphd/generative_photography"
4
+ pretrained_model_path: "stable-diffusion-v1-5"
5
+
6
+ unet_subfolder: "unet_merged"
7
+
8
+ camera_adaptor_ckpt: "weights/checkpoint-color_temperature.ckpt"
9
+ lora_ckpt: "weights/RealEstate10K_LoRA.ckpt"
10
+ motion_module_ckpt: "weights/v3_sd15_mm.ckpt"
11
+
12
+ lora_rank: 2
13
+ lora_scale: 1.0
14
+ motion_lora_rank: 0
15
+ motion_lora_scale: 1.0
16
+
17
+ unet_additional_kwargs:
18
+ use_motion_module : true
19
+ motion_module_resolutions : [ 1,2,4,8 ]
20
+ unet_use_cross_frame_attention : false
21
+ unet_use_temporal_attention : false
22
+ motion_module_mid_block: false
23
+ motion_module_type: Vanilla
24
+ motion_module_kwargs:
25
+ num_attention_heads : 8
26
+ num_transformer_block : 1
27
+ attention_block_types : [ "Temporal_Self", "Temporal_Self" ]
28
+ temporal_position_encoding : true
29
+ temporal_position_encoding_max_len : 32
30
+ temporal_attention_dim_div : 1
31
+ zero_initialize : false
32
+
33
+
34
+ camera_encoder_kwargs:
35
+ downscale_factor: 8
36
+ channels: [320, 640, 1280, 1280]
37
+ nums_rb: 2
38
+ cin: 384
39
+ ksize: 1
40
+ sk: true
41
+ use_conv: false
42
+ compression_factor: 1
43
+ temporal_attention_nhead: 8
44
+ attention_block_types: ["Temporal_Self", ]
45
+ temporal_position_encoding: true
46
+ temporal_position_encoding_max_len: 16
47
+ attention_processor_kwargs:
48
+ add_spatial: false
49
+ spatial_attn_names: 'attn1'
50
+ add_temporal: true
51
+ temporal_attn_names: '0'
52
+ camera_feature_dimensions: [320, 640, 1280, 1280]
53
+ query_condition: true
54
+ key_value_condition: true
55
+ scale: 1.0
56
+ noise_scheduler_kwargs:
57
+ num_train_timesteps: 1000
58
+ beta_start: 0.00085
59
+ beta_end: 0.012
60
+ beta_schedule: "linear"
61
+ steps_offset: 1
62
+ clip_sample: false
63
+
64
+
65
+ num_workers: 8
66
+ global_seed: 42
configs/inference_genphoto/adv3_256_384_genphoto_relora_focal_length.yaml CHANGED
@@ -1,3 +1,65 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:1c8c9059792e1ca206c44edd1cb29765c5ddb1f54551a1b1fc7010bf292420a8
3
- size 1834
1
+ output_dir: "inference_output/genphoto_focal_length"
2
+
3
+ pretrained_model_repo: "pandaphd/generative_photography"
4
+ pretrained_model_path: "stable-diffusion-v1-5"
5
+
6
+ unet_subfolder: "unet_merged"
7
+
8
+ camera_adaptor_ckpt: "weights/checkpoint-focal_length.ckpt"
9
+ lora_ckpt: "weights/RealEstate10K_LoRA.ckpt"
10
+ motion_module_ckpt: "weights/v3_sd15_mm.ckpt"
11
+
12
+ lora_rank: 2
13
+ lora_scale: 1.0
14
+ motion_lora_rank: 0
15
+ motion_lora_scale: 1.0
16
+
17
+ unet_additional_kwargs:
18
+ use_motion_module : true
19
+ motion_module_resolutions : [ 1,2,4,8 ]
20
+ unet_use_cross_frame_attention : false
21
+ unet_use_temporal_attention : false
22
+ motion_module_mid_block: false
23
+ motion_module_type: Vanilla
24
+ motion_module_kwargs:
25
+ num_attention_heads : 8
26
+ num_transformer_block : 1
27
+ attention_block_types : [ "Temporal_Self", "Temporal_Self" ]
28
+ temporal_position_encoding : true
29
+ temporal_position_encoding_max_len : 32
30
+ temporal_attention_dim_div : 1
31
+ zero_initialize : false
32
+
33
+ camera_encoder_kwargs:
34
+ downscale_factor: 8
35
+ channels: [320, 640, 1280, 1280]
36
+ nums_rb: 2
37
+ cin: 384
38
+ ksize: 1
39
+ sk: true
40
+ use_conv: false
41
+ compression_factor: 1
42
+ temporal_attention_nhead: 8
43
+ attention_block_types: ["Temporal_Self", ]
44
+ temporal_position_encoding: true
45
+ temporal_position_encoding_max_len: 16
46
+ attention_processor_kwargs:
47
+ add_spatial: false
48
+ spatial_attn_names: 'attn1'
49
+ add_temporal: true
50
+ temporal_attn_names: '0'
51
+ camera_feature_dimensions: [320, 640, 1280, 1280]
52
+ query_condition: true
53
+ key_value_condition: true
54
+ scale: 1.0
55
+ noise_scheduler_kwargs:
56
+ num_train_timesteps: 1000
57
+ beta_start: 0.00085
58
+ beta_end: 0.012
59
+ beta_schedule: "linear"
60
+ steps_offset: 1
61
+ clip_sample: false
62
+
63
+
64
+ num_workers: 8
65
+ global_seed: 42
configs/inference_genphoto/adv3_256_384_genphoto_relora_shutter_speed.yaml CHANGED
@@ -1,3 +1,66 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:36c7d618e29249ce9086f5424f9a718b0faac002edd19ee0fd0335b85fdc8b7f
3
- size 1837
1
+ output_dir: "inference_output/genphoto_shutter_speed"
2
+
3
+ pretrained_model_repo: "pandaphd/generative_photography"
4
+ pretrained_model_path: "stable-diffusion-v1-5"
5
+
6
+ unet_subfolder: "unet_merged"
7
+
8
+ camera_adaptor_ckpt: "weights/checkpoint-shutter_speed.ckpt"
9
+ lora_ckpt: "weights/RealEstate10K_LoRA.ckpt"
10
+ motion_module_ckpt: "weights/v3_sd15_mm.ckpt"
11
+
12
+ lora_rank: 2
13
+ lora_scale: 1.0
14
+ motion_lora_rank: 0
15
+ motion_lora_scale: 1.0
16
+
17
+ unet_additional_kwargs:
18
+ use_motion_module : true
19
+ motion_module_resolutions : [ 1,2,4,8 ]
20
+ unet_use_cross_frame_attention : false
21
+ unet_use_temporal_attention : false
22
+ motion_module_mid_block: false
23
+ motion_module_type: Vanilla
24
+ motion_module_kwargs:
25
+ num_attention_heads : 8
26
+ num_transformer_block : 1
27
+ attention_block_types : [ "Temporal_Self", "Temporal_Self" ]
28
+ temporal_position_encoding : true
29
+ temporal_position_encoding_max_len : 32
30
+ temporal_attention_dim_div : 1
31
+ zero_initialize : false
32
+
33
+
34
+ camera_encoder_kwargs:
35
+ downscale_factor: 8
36
+ channels: [320, 640, 1280, 1280]
37
+ nums_rb: 2
38
+ cin: 384
39
+ ksize: 1
40
+ sk: true
41
+ use_conv: false
42
+ compression_factor: 1
43
+ temporal_attention_nhead: 8
44
+ attention_block_types: ["Temporal_Self", ]
45
+ temporal_position_encoding: true
46
+ temporal_position_encoding_max_len: 16
47
+ attention_processor_kwargs:
48
+ add_spatial: false
49
+ spatial_attn_names: 'attn1'
50
+ add_temporal: true
51
+ temporal_attn_names: '0'
52
+ camera_feature_dimensions: [320, 640, 1280, 1280]
53
+ query_condition: true
54
+ key_value_condition: true
55
+ scale: 1.0
56
+ noise_scheduler_kwargs:
57
+ num_train_timesteps: 1000
58
+ beta_start: 0.00085
59
+ beta_end: 0.012
60
+ beta_schedule: "linear"
61
+ steps_offset: 1
62
+ clip_sample: false
63
+
64
+
65
+ num_workers: 8
66
+ global_seed: 42
environment.yaml CHANGED
@@ -1,3 +1,27 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:a55fe5d623a3450e046bd7d0d095676d9d2ca62d36d19cfda8e9307007634970
3
- size 435
1
+ name: genphoto
2
+ channels:
3
+ - pytorch
4
+ - nvidia
5
+ dependencies:
6
+ - python=3.10
7
+ - pytorch=2.1.1
8
+ - torchvision=0.16.1
9
+ - torchaudio=2.1.1
10
+ - pytorch-cuda=12.1
11
+ - pip
12
+ - pip:
13
+ - diffusers==0.24.0
14
+ - xformers==0.0.23
15
+ - imageio==2.36.0
16
+ - imageio[ffmpeg]
17
+ - opencv-python
18
+ - transformers
19
+ - gdown
20
+ - einops
21
+ - decord
22
+ - omegaconf
23
+ - safetensors
24
+ - gradio
25
+ - wandb
26
+ - triton
27
+ - termcolor
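
After creating the environment (conda env create -f environment.yaml), a quick sanity check of the pinned versions might look like the sketch below; the commented version strings are the expected values from the file above, not guaranteed output:

    import torch, torchvision, diffusers, xformers

    print(torch.__version__)          # expected 2.1.1 (CUDA 12.1 build)
    print(torchvision.__version__)    # expected 0.16.1
    print(diffusers.__version__)      # expected 0.24.0
    print(xformers.__version__)       # expected 0.0.23
    print(torch.cuda.is_available())  # True if the CUDA 12.1 runtime is usable
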
genphoto/data/dataset.py CHANGED
@@ -1,3 +1,950 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:4c11d5ea01a3dd35a0987915a62ffb2c4c967ff4c81d2c9f0fe876f2daa93aad
3
- size 38885
1
+ import os
2
+ import random
3
+ import json
4
+ import torch
5
+ import math
6
+ import torch.nn as nn
7
+ import torchvision.transforms as transforms
8
+ import torch.nn.functional as F
9
+ import numpy as np
10
+ from torch.utils.data.dataset import Dataset
11
+ from packaging import version as pver
12
+ import cv2
13
+ from PIL import Image
14
+ from einops import rearrange
15
+ from transformers import pipeline, CLIPTextModel, CLIPTokenizer
16
+
17
+ import sys
18
+ sys.path.append('/home/yuan418/data/project/Generative_Photography/genphoto/data/BokehMe/')
19
+ from classical_renderer.scatter import ModuleRenderScatter
20
+
21
+
22
+
23
+ #### for shutter speed ####
24
+ def create_shutter_speed_embedding(shutter_speed_values, target_height, target_width, base_exposure=0.5):
25
+ """
26
+ Create a shutter_speed embedding tensor using a constant full well capacity (fwc) value.
27
+ Args:
28
+ - shutter_speed_values: Tensor of shape [f, 1] containing shutter_speed values for each frame.
29
+ - target_height: Height of the image.
30
+ - target_width: Width of the image.
31
+ - base_exposure: A base exposure value used to normalize brightness (defaults to 0.5).
32
+
33
+ Returns:
34
+ - shutter_speed_embedding: Tensor of shape [f, 3, H, W] where each pixel is scaled based on the shutter_speed values.
35
+ """
36
+ f = shutter_speed_values.shape[0]
37
+
38
+ # Set a constant full well capacity (fwc)
39
+ fwc = 32000 # Constant value for full well capacity
40
+
41
+ # Calculate scale based on EV and sensor full well capacity (fwc)
42
+ scales = (shutter_speed_values / base_exposure) * (fwc / (fwc + 0.0001))
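+ # Note: with fwc = 32000 the factor fwc / (fwc + 0.0001) is ~1, so the per-frame
+ # scale is effectively shutter_speed / base_exposure; the 0.0001 term is only a
+ # numerical guard on the division.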
43
+
44
+ # Reshape and expand to match image dimensions
45
+ scales = scales.unsqueeze(2).unsqueeze(3).expand(f, 3, target_height, target_width)
46
+
47
+ # Use scales to create the final shutter_speed embedding
48
+ shutter_speed_embedding = scales # Shape [f, 3, H, W]
49
+ return shutter_speed_embedding
50
+
51
+
52
+ def sensor_image_simulation_numpy(avg_PPP, photon_flux, fwc, Nbits, gain=1):
53
+ min_val = 0
54
+ max_val = 2 ** Nbits - 1
55
+ theta = photon_flux * (avg_PPP / (np.mean(photon_flux) + 0.0001))
56
+ theta = np.clip(theta, 0, fwc)
57
+ theta = np.round(theta * gain * max_val / fwc)
58
+ theta = np.clip(theta, min_val, max_val)
59
+ theta = theta.astype(np.float32)
60
+ return theta
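+ # Illustration (not from the original code): with Nbits=8 and gain=1, a pixel sitting
+ # at the mean photon flux maps to roughly round(avg_PPP * 255 / fwc), i.e. the requested
+ # average exposure expressed on the 8-bit output range.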
61
+
62
+
63
+ class CameraShutterSpeed(Dataset):
64
+ def __init__(
65
+ self,
66
+ root_path,
67
+ annotation_json,
68
+ sample_n_frames=5,
69
+ sample_size=[256, 384],
70
+ is_Train=True,
71
+ ):
72
+ self.root_path = root_path
73
+ self.sample_n_frames = sample_n_frames
74
+ self.dataset = json.load(open(os.path.join(root_path, annotation_json), 'r'))
75
+ self.length = len(self.dataset)
76
+ self.is_Train = is_Train
77
+ sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
78
+ self.sample_size = sample_size
79
+
80
+ pixel_transforms = [transforms.Resize(sample_size),
81
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)]
82
+
83
+ self.pixel_transforms = pixel_transforms
84
+ self.tokenizer = CLIPTokenizer.from_pretrained("/home/yuan418/data/project/stable-diffusion-v1-5/", subfolder="tokenizer")
85
+ self.text_encoder = CLIPTextModel.from_pretrained("/home/yuan418/data/project/stable-diffusion-v1-5/", subfolder="text_encoder")
86
+
87
+ def load_image_reader(self, idx):
88
+ image_dict = self.dataset[idx]
89
+ image_path = os.path.join(self.root_path, image_dict['base_image_path'])
90
+ image_reader = cv2.imread(image_path)
91
+ image_reader = cv2.cvtColor(image_reader, cv2.COLOR_BGR2RGB)
92
+ image_caption = image_dict['caption']
93
+
94
+ if self.is_Train:
95
+ mean = 0.48
96
+ std_dev = 0.25
97
+ shutter_speed_values = [random.gauss(mean, std_dev) for _ in range(self.sample_n_frames)]
98
+ shutter_speed_values = [max(0.1, min(1.0, ev)) for ev in shutter_speed_values]
99
+ print('train shutter_speed values', shutter_speed_values)
100
+
101
+ else:
102
+ shutter_speed_list_str = image_dict['shutter_speed_list']
103
+ shutter_speed_values = json.loads(shutter_speed_list_str)
104
+ print('validation shutter_speed_values', shutter_speed_values)
105
+
106
+ shutter_speed_values = torch.tensor(shutter_speed_values).unsqueeze(1)
107
+ return image_path, image_reader, image_caption, shutter_speed_values
108
+
109
+
110
+ def get_batch(self, idx):
111
+ image_path, image_reader, image_caption, shutter_speed_values = self.load_image_reader(idx)
112
+
113
+ total_frames = len(shutter_speed_values)
114
+ if total_frames < 3:
115
+ raise ValueError("less than 3 frames")
116
+
117
+ # Generate prompts for each shutter speed value and append shutter speed information to caption
118
+ prompts = []
119
+ for ss in shutter_speed_values:
120
+ prompt = f"<exposure: {ss.item()}>"
121
+ prompts.append(prompt)
122
+
123
+ # Tokenize prompts and encode to get embeddings
124
+ with torch.no_grad():
125
+ prompt_ids = self.tokenizer(
126
+ prompts, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
127
+ ).input_ids
128
+ # print('tokenizer model_max_length', self.tokenizer.model_max_length)
129
+
130
+ encoder_hidden_states = self.text_encoder(input_ids=prompt_ids).last_hidden_state # Shape: (f, sequence_length, hidden_size)
131
+
132
+ # print('encoder_hidden_states shape', encoder_hidden_states.shape)
133
+
134
+ # Calculate differences between consecutive embeddings (ignoring sequence_length)
135
+ differences = []
136
+ for i in range(1, encoder_hidden_states.size(0)):
137
+ diff = encoder_hidden_states[i] - encoder_hidden_states[i - 1]
138
+ diff = diff.unsqueeze(0)
139
+ differences.append(diff)
140
+
141
+ # Add the difference between the last and the first embedding
142
+ final_diff = encoder_hidden_states[-1] - encoder_hidden_states[0]
143
+ final_diff = final_diff.unsqueeze(0)
144
+ differences.append(final_diff)
145
+
146
+ # Concatenate the f differences along the batch dimension (f-1 consecutive diffs plus the last-to-first wrap-around)
147
+ concatenated_differences = torch.cat(differences, dim=0)
148
+ # print('concatenated_differences shape', concatenated_differences.shape) # f 77 768
149
+
150
+ frame = concatenated_differences.size(0)
151
+
152
+ concatenated_differences = torch.cat(differences, dim=0)
153
+
154
+ # Current shape: (f, 77, 768) Pad the second dimension (77) to 128
155
+ pad_length = 128 - concatenated_differences.size(1)
156
+ if pad_length > 0:
157
+ # Pad along the second dimension (77 -> 128), pad only on the right side
158
+ concatenated_differences_padded = F.pad(concatenated_differences, (0, 0, 0, pad_length))
159
+
160
+ ## ccl = contrastive camera learning
161
+ ccl_embedding = concatenated_differences_padded.reshape(frame, self.sample_size[0], self.sample_size[1])
162
+ ccl_embedding = ccl_embedding.unsqueeze(1)
163
+ ccl_embedding = ccl_embedding.expand(-1, 3, -1, -1)
164
+
165
+ # Now handle the sensor image simulation
166
+ fwc = random.uniform(19000, 64000)
167
+ pixel_values = []
168
+ for ee in shutter_speed_values:
169
+ avg_PPP = (0.6 * ee.item() + 0.1) * fwc
170
+ img_sim = sensor_image_simulation_numpy(avg_PPP, image_reader, fwc, Nbits=8, gain=1)
171
+ pixel_values.append(img_sim)
172
+ pixel_values = np.stack(pixel_values, axis=0)
173
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous() / 255.
174
+
175
+ # Create shutter_speed embedding and concatenate it with CCL embedding
176
+ shutter_speed_embedding = create_shutter_speed_embedding(shutter_speed_values, self.sample_size[0], self.sample_size[1])
177
+
178
+ camera_embedding = torch.cat((shutter_speed_embedding, ccl_embedding), dim=1)
179
+ # print('camera_embedding shape', camera_embedding.shape)
180
+
181
+ return pixel_values, image_caption, camera_embedding, shutter_speed_values
182
+
183
+ def __len__(self):
184
+ return self.length
185
+
186
+ def __getitem__(self, idx):
187
+ while True:
188
+ try:
189
+ video, video_caption, camera_embedding, shutter_speed_values = self.get_batch(idx)
190
+ break
191
+ except Exception as e:
192
+ idx = random.randint(0, self.length - 1)
193
+
194
+ for transform in self.pixel_transforms:
195
+ video = transform(video)
196
+
197
+ sample = dict(pixel_values=video, text=video_caption, camera_embedding=camera_embedding, shutter_speed_values=shutter_speed_values)
198
+
199
+ return sample
200
+
201
+
202
+
203
+
204
+
205
+
206
+
207
+
208
+ #### for focal length ####
209
+ def crop_focal_length(image_path, base_focal_length, target_focal_length, target_height, target_width, sensor_height=24.0, sensor_width=36.0):
210
+ img = Image.open(image_path)
211
+ width, height = img.size
212
+
213
+ # Calculate base and target FOV
214
+ base_x_fov = 2.0 * math.atan(sensor_width * 0.5 / base_focal_length)
215
+ base_y_fov = 2.0 * math.atan(sensor_height * 0.5 / base_focal_length)
216
+
217
+ target_x_fov = 2.0 * math.atan(sensor_width * 0.5 / target_focal_length)
218
+ target_y_fov = 2.0 * math.atan(sensor_height * 0.5 / target_focal_length)
219
+
220
+ # Calculate crop ratio, use the smaller ratio to maintain aspect ratio
221
+ crop_ratio = min(target_x_fov / base_x_fov, target_y_fov / base_y_fov)
222
+
223
+ crop_width = int(round(crop_ratio * width))
224
+ crop_height = int(round(crop_ratio * height))
225
+
226
+ # Ensure crop dimensions are within valid bounds
227
+ crop_width = max(1, min(width, crop_width))
228
+ crop_height = max(1, min(height, crop_height))
229
+
230
+ # Crop coordinates
231
+ left = int((width - crop_width) / 2)
232
+ top = int((height - crop_height) / 2)
233
+ right = int((width + crop_width) / 2)
234
+ bottom = int((height + crop_height) / 2)
235
+
236
+ # Crop the image
237
+ zoomed_img = img.crop((left, top, right, bottom))
238
+
239
+ # Resize the cropped image to target resolution
240
+ resized_img = zoomed_img.resize((target_width, target_height), Image.Resampling.LANCZOS)
241
+
242
+ # Convert the PIL image to a numpy array
243
+ resized_img_np = np.array(resized_img).astype(np.float32)
244
+
245
+ return resized_img_np
246
+
247
+
248
+ def create_focal_length_embedding(focal_length_values, base_focal_length, target_height, target_width, sensor_height=24.0, sensor_width=36.0):
249
+ device = 'cpu'
250
+ focal_length_values = focal_length_values.to(device)
251
+
252
+ f = focal_length_values.shape[0] # Number of frames
253
+
254
+ # Convert constants to tensors to perform operations with focal_length_values
255
+ sensor_width = torch.tensor(sensor_width, device=device)
256
+ sensor_height = torch.tensor(sensor_height, device=device)
257
+ base_focal_length = torch.tensor(base_focal_length, device=device)
258
+
259
+ # Calculate the FOV for the base focal length (min_focal_length)
260
+ base_fov_x = 2.0 * torch.atan(sensor_width * 0.5 / base_focal_length)
261
+ base_fov_y = 2.0 * torch.atan(sensor_height * 0.5 / base_focal_length)
262
+
263
+ # Calculate the FOV for each focal length in focal_length_values
264
+ target_fov_x = 2.0 * torch.atan(sensor_width * 0.5 / focal_length_values)
265
+ target_fov_y = 2.0 * torch.atan(sensor_height * 0.5 / focal_length_values)
266
+
267
+ # Calculate crop ratio: how much of the image is cropped at the current focal length
268
+ crop_ratio_xs = target_fov_x / base_fov_x # Crop ratio for horizontal axis
269
+ crop_ratio_ys = target_fov_y / base_fov_y # Crop ratio for vertical axis
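+ # Because the focal lengths used here are >= the 24 mm base, the target FOV is never
+ # wider than the base FOV, so both crop ratios lie in (0, 1]; the loop below marks the
+ # central region that a longer lens would actually capture.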
270
+
271
+ # Get the center of the image
272
+ center_h, center_w = target_height // 2, target_width // 2
273
+
274
+ # Initialize a mask tensor with zeros on CPU
275
+ focal_length_embedding = torch.zeros((f, 3, target_height, target_width), dtype=torch.float32) # Shape [f, 3, H, W]
276
+
277
+ # Fill the center region with 1 based on the calculated crop dimensions
278
+ for i in range(f):
279
+ # Crop dimensions calculated using rounded float values
280
+ crop_h = torch.round(crop_ratio_ys[i] * target_height).int().item() # Rounded cropped height for the current frame
281
+ crop_w = torch.round(crop_ratio_xs[i] * target_width).int().item() # Rounded cropped width for the current frame
282
+
283
+ # Ensure the cropped dimensions are within valid bounds
284
+ crop_h = max(1, min(target_height, crop_h))
285
+ crop_w = max(1, min(target_width, crop_w))
286
+
287
+ # Set the center region of the focal_length embedding to 1 for the current frame
288
+ focal_length_embedding[i, :,
289
+ center_h - crop_h // 2: center_h + crop_h // 2,
290
+ center_w - crop_w // 2: center_w + crop_w // 2] = 1.0
291
+
292
+ return focal_length_embedding
293
+
294
+
295
+ class CameraFocalLength(Dataset):
296
+ def __init__(
297
+ self,
298
+ root_path,
299
+ annotation_json,
300
+ sample_n_frames=5,
301
+ sample_size=[256, 384],
302
+ is_Train=True,
303
+ ):
304
+ self.root_path = root_path
305
+ self.sample_n_frames = sample_n_frames
306
+ self.dataset = json.load(open(os.path.join(root_path, annotation_json), 'r'))
307
+ self.length = len(self.dataset)
308
+ sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
309
+ self.sample_size = sample_size
310
+ pixel_transforms = [transforms.Resize(sample_size),
311
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)]
312
+
313
+ self.pixel_transforms = pixel_transforms
314
+ self.is_Train = is_Train
315
+ self.tokenizer = CLIPTokenizer.from_pretrained("/home/yuan418/data/project/stable-diffusion-v1-5/", subfolder="tokenizer")
316
+ self.text_encoder = CLIPTextModel.from_pretrained("/home/yuan418/data/project/stable-diffusion-v1-5/", subfolder="text_encoder")
317
+
318
+
319
+ def load_image_reader(self, idx):
320
+ image_dict = self.dataset[idx]
321
+
322
+ image_path = os.path.join(self.root_path, image_dict['base_image_path'])
323
+ image_reader = cv2.imread(image_path)
324
+
325
+ image_caption = image_dict['caption']
326
+
327
+ if self.is_Train:
328
+ focal_length_values = [random.uniform(24.0, 70.0) for _ in range(self.sample_n_frames)]
329
+ print('train focal_length_values', focal_length_values)
330
+ else:
331
+ focal_length_list_str = image_dict['focal_length_list']
332
+ focal_length_values = json.loads(focal_length_list_str)
333
+ print('validation focal_length_values', focal_length_values)
334
+
335
+ focal_length_values = torch.tensor(focal_length_values).unsqueeze(1)
336
+
337
+ return image_path, image_reader, image_caption, focal_length_values
338
+
339
+
340
+ def get_batch(self, idx):
341
+ image_path, image_reader, image_caption, focal_length_values = self.load_image_reader(idx)
342
+
343
+ total_frames = len(focal_length_values)
344
+ if total_frames < 3:
345
+ raise ValueError("less than 3 frames")
346
+
347
+ # Generate prompts for each fl value and append fl information to caption
348
+ prompts = []
349
+ for fl in focal_length_values:
350
+ prompt = f"<focal length: {fl.item()}>"
351
+ prompts.append(prompt)
352
+
353
+ # Tokenize prompts and encode to get embeddings
354
+ with torch.no_grad():
355
+ prompt_ids = self.tokenizer(
356
+ prompts, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
357
+ ).input_ids
358
+
359
+ encoder_hidden_states = self.text_encoder(input_ids=prompt_ids).last_hidden_state # Shape: (f, sequence_length, hidden_size)
360
+ # print('encoder_hidden_states shape', encoder_hidden_states.shape)
361
+
362
+ # Calculate differences between consecutive embeddings (ignoring sequence_length)
363
+ differences = []
364
+ for i in range(1, encoder_hidden_states.size(0)):
365
+ diff = encoder_hidden_states[i] - encoder_hidden_states[i - 1]
366
+ diff = diff.unsqueeze(0)
367
+ differences.append(diff)
368
+
369
+ # Add the difference between the last and the first embedding
370
+ final_diff = encoder_hidden_states[-1] - encoder_hidden_states[0]
371
+ final_diff = final_diff.unsqueeze(0)
372
+ differences.append(final_diff)
373
+
374
+ # Concatenate the f differences along the batch dimension (f-1 consecutive diffs plus the last-to-first wrap-around)
375
+ concatenated_differences = torch.cat(differences, dim=0)
376
+ # print('concatenated_differences shape', concatenated_differences.shape) # f 77 768
377
+
378
+ frame = concatenated_differences.size(0)
379
+
380
+ # Concatenate differences along the batch dimension (f)
381
+ concatenated_differences = torch.cat(differences, dim=0)
382
+
383
+ # Current shape: (f, 77, 768), Pad the second dimension (77) to 128
384
+ pad_length = 128 - concatenated_differences.size(1)
385
+ if pad_length > 0:
386
+ # Pad along the second dimension (77 -> 128), pad only on the right side
387
+ concatenated_differences_padded = F.pad(concatenated_differences, (0, 0, 0, pad_length))
388
+
389
+ ## CCL = contrastive camera learning
390
+ ccl_embedding = concatenated_differences_padded.reshape(frame, self.sample_size[0], self.sample_size[1])
391
+
392
+ ccl_embedding = ccl_embedding.unsqueeze(1)
393
+ ccl_embedding = ccl_embedding.expand(-1, 3, -1, -1)
394
+ # print('ccl_embedding shape', ccl_embedding.shape)
395
+
396
+ pixel_values = []
397
+ for ff in focal_length_values:
398
+ img_sim = crop_focal_length(image_path=image_path, base_focal_length=24.0, target_focal_length=ff, target_height=self.sample_size[0], target_width=self.sample_size[1], sensor_height=24.0, sensor_width=36.0)
399
+
400
+ pixel_values.append(img_sim)
401
+ # save_path = os.path.join(self.root_path, f"simulated_img_focal_length_{fl.item():.2f}.png")
402
+ # cv2.imwrite(save_path, img_sim)
403
+ # print(f"Saved image: {save_path}")
404
+
405
+ pixel_values = np.stack(pixel_values, axis=0)
406
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous() / 255.
407
+
408
+ focal_length_embedding = create_focal_length_embedding(focal_length_values, base_focal_length=24.0, target_height=self.sample_size[0], target_width=self.sample_size[1])
409
+ # print('focal_length_embedding shape', focal_length_embedding.shape)
410
+
411
+ camera_embedding = torch.cat((focal_length_embedding, ccl_embedding), dim=1)
412
+ # print('camera_embedding shape', camera_embedding.shape)
413
+
414
+ return pixel_values, image_caption, camera_embedding, focal_length_values
415
+
416
+ def __len__(self):
417
+ return self.length
418
+
419
+ def __getitem__(self, idx):
420
+ while True:
421
+ try:
422
+ video, video_caption, camera_embedding, focal_length_values = self.get_batch(idx)
423
+ break
424
+ except Exception as e:
425
+ idx = random.randint(0, self.length - 1)
426
+
427
+ for transform in self.pixel_transforms:
428
+ video = transform(video)
429
+
430
+ sample = dict(pixel_values=video, text=video_caption, camera_embedding=camera_embedding, focal_length_values=focal_length_values)
431
+
432
+ return sample
433
+
434
+
435
+
436
+
437
+
438
+
439
+
440
+ #### for color temperature ####
441
+ def kelvin_to_rgb(kelvin):
442
+ temp = kelvin / 100.0
443
+
444
+ if temp <= 66:
445
+ red = 255
446
+ green = 99.4708025861 * np.log(temp) - 161.1195681661 if temp > 0 else 0
447
+ if temp <= 19:
448
+ blue = 0
449
+ else:
450
+ blue = 138.5177312231 * np.log(temp - 10) - 305.0447927307
451
+
452
+ elif 66 < temp <= 88:
453
+ red = 0.5 * (255 + 329.698727446 * ((temp - 60) ** -0.19332047592))
454
+ green = 0.5 * (288.1221695283 * ((temp - 60) ** -0.1155148492) + (99.4708025861 * np.log(temp) - 161.1195681661 if temp > 0 else 0))
455
+ blue = 0.5 * (138.5177312231 * np.log(temp - 10) - 305.0447927307 + 255)
456
+
457
+ else:
458
+ red = 329.698727446 * ((temp - 60) ** -0.19332047592)
459
+ green = 288.1221695283 * ((temp - 60) ** -0.1155148492)
460
+ blue = 255
461
+
462
+ return np.array([red, green, blue], dtype=np.float32) / 255.0
463
+
464
+
465
+
466
+ def create_color_temperature_embedding(color_temperature_values, target_height, target_width, min_color_temperature=2000, max_color_temperature=10000):
467
+ """
468
+ Create a color_temperature embedding tensor based on color temperature.
469
+ Args:
470
+ - color_temperature_values: Tensor of shape [f, 1] containing color_temperature values for each frame.
471
+ - target_height: Height of the image.
472
+ - target_width: Width of the image.
473
+ - min_color_temperature: Minimum color_temperature value for normalization.
474
+ - max_color_temperature: Maximum color_temperature value for normalization.
475
+ Returns:
476
+ - color_temperature_embedding: Tensor of shape [f, 3, target_height, target_width] for RGB channel scaling.
477
+ """
478
+ f = color_temperature_values.shape[0]
479
+ rgb_factors = []
480
+
481
+ # Compute RGB factors based on kelvin_to_rgb function
482
+ for ct in color_temperature_values.squeeze():
483
+ kelvin = min_color_temperature + (ct * (max_color_temperature - min_color_temperature)) # Map normalized color_temperature to actual Kelvin
484
+ rgb = kelvin_to_rgb(kelvin)
485
+ rgb_factors.append(rgb)
486
+
487
+ # Convert to tensor and expand to target dimensions
488
+ rgb_factors = torch.tensor(rgb_factors).float() # [f, 3]
489
+ rgb_factors = rgb_factors.unsqueeze(2).unsqueeze(3) # [f, 3, 1, 1]
490
+ color_temperature_embedding = rgb_factors.expand(f, 3, target_height, target_width) # [f, 3, target_height, target_width]
491
+ return color_temperature_embedding
492
+
493
+
494
+
495
+ def kelvin_to_rgb_smooth(kelvin):
496
+ temp = kelvin / 100.0
497
+
498
+ if temp <= 66:
499
+ red = 255
500
+ green = 99.4708025861 * np.log(temp) - 161.1195681661 if temp > 0 else 0
501
+ if temp <= 19:
502
+ blue = 0
503
+ else:
504
+ blue = 138.5177312231 * np.log(temp - 10) - 305.0447927307
505
+
506
+ elif 66 < temp <= 88:
507
+ red = 0.5 * (255 + 329.698727446 * ((temp - 60) ** -0.19332047592))
508
+ green = 0.5 * (288.1221695283 * ((temp - 60) ** -0.1155148492) + (99.4708025861 * np.log(temp) - 161.1195681661 if temp > 0 else 0))
509
+ blue = 0.5 * (138.5177312231 * np.log(temp - 10) - 305.0447927307 + 255)
510
+
511
+ else:
512
+ red = 329.698727446 * ((temp - 60) ** -0.19332047592)
513
+ green = 288.1221695283 * ((temp - 60) ** -0.1155148492)
514
+ blue = 255
515
+
516
+ red = np.clip(red, 0, 255)
517
+ green = np.clip(green, 0, 255)
518
+ blue = np.clip(blue, 0, 255)
519
+ balance_rgb = np.array([red, green, blue], dtype=np.float32)
520
+
521
+ return balance_rgb
522
+
523
+
524
+ def interpolate_white_balance(image, kelvin):
525
+
526
+ balance_rgb = kelvin_to_rgb_smooth(kelvin.item())
527
+ image = image.astype(np.float32)
528
+
529
+ r, g, b = cv2.split(image)
530
+ r = r * (balance_rgb[0] / 255.0)
531
+ g = g * (balance_rgb[1] / 255.0)
532
+ b = b * (balance_rgb[2] / 255.0)
533
+
534
+ balanced_image = cv2.merge([r,g,b])
535
+ balanced_image = np.clip(balanced_image, 0, 255).astype(np.uint8)
536
+
537
+ return balanced_image
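+ # Rough behavior of the approximation above: near 6500 K balance_rgb comes out close to
+ # (255, 254, 250), leaving the image almost unchanged; lower temperatures suppress the
+ # blue channel (warm cast) and higher temperatures suppress red/green (cool cast).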
538
+
539
+
540
+ class CameraColorTemperature(Dataset):
541
+ def __init__(
542
+ self,
543
+ root_path,
544
+ annotation_json,
545
+ sample_n_frames=5,
546
+ sample_size=[256, 384],
547
+ is_Train=True,
548
+ ):
549
+ self.root_path = root_path
550
+ self.sample_n_frames = sample_n_frames
551
+ self.dataset = json.load(open(os.path.join(root_path, annotation_json), 'r'))
552
+
553
+ self.length = len(self.dataset)
554
+ self.is_Train = is_Train
555
+
556
+ sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
557
+ self.sample_size = sample_size
558
+
559
+ pixel_transforms = [transforms.Resize(sample_size),
560
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)]
561
+
562
+ self.pixel_transforms = pixel_transforms
563
+ self.tokenizer = CLIPTokenizer.from_pretrained("/home/yuan418/data/project/stable-diffusion-v1-5/", subfolder="tokenizer")
564
+ self.text_encoder = CLIPTextModel.from_pretrained("/home/yuan418/data/project/stable-diffusion-v1-5/", subfolder="text_encoder")
565
+
566
+ def load_image_reader(self, idx):
567
+ image_dict = self.dataset[idx]
568
+
569
+ image_path = os.path.join(self.root_path, image_dict['base_image_path'])
570
+ image_reader = cv2.imread(image_path)
571
+ image_reader = cv2.cvtColor(image_reader, cv2.COLOR_BGR2RGB)
572
+
573
+ image_caption = image_dict['caption']
574
+
575
+ if self.is_Train:
576
+ color_temperature_values = [random.uniform(2000.0, 10000.0) for _ in range(self.sample_n_frames)]
577
+ print('train color_temperature values', color_temperature_values)
578
+
579
+ else:
580
+ color_temperature_list_str = image_dict['color_temperature_list']
581
+ color_temperature_values = json.loads(color_temperature_list_str)
582
+ print('validation color_temperature_values', color_temperature_values)
583
+
584
+ color_temperature_values = torch.tensor(color_temperature_values).unsqueeze(1)
585
+ return image_path, image_reader, image_caption, color_temperature_values
586
+
587
+
588
+ def get_batch(self, idx):
589
+ image_path, image_reader, image_caption, color_temperature_values = self.load_image_reader(idx)
590
+
591
+ total_frames = len(color_temperature_values)
592
+ if total_frames < 3:
593
+ raise ValueError("less than 3 frames")
594
+
595
+ # Generate prompts for each color_temperature value and append color_temperature information to caption
596
+ prompts = []
597
+ for cc in color_temperature_values:
598
+ prompt = f"<color temperature: {cc.item()}>"
599
+ prompts.append(prompt)
600
+
601
+ # Tokenize prompts and encode to get embeddings
602
+ with torch.no_grad():
603
+ prompt_ids = self.tokenizer(
604
+ prompts, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
605
+ ).input_ids
606
+ # print('tokenizer model_max_length', self.tokenizer.model_max_length)
607
+
608
+ encoder_hidden_states = self.text_encoder(input_ids=prompt_ids).last_hidden_state # Shape: (f, sequence_length, hidden_size)
609
+
610
+ # print('encoder_hidden_states shape', encoder_hidden_states.shape)
611
+
612
+ # Calculate differences between consecutive embeddings (ignoring sequence_length)
613
+ differences = []
614
+ for i in range(1, encoder_hidden_states.size(0)):
615
+ diff = encoder_hidden_states[i] - encoder_hidden_states[i - 1]
616
+ diff = diff.unsqueeze(0)
617
+ differences.append(diff)
618
+
619
+ # Add the difference between the last and the first embedding
620
+ final_diff = encoder_hidden_states[-1] - encoder_hidden_states[0]
621
+ final_diff = final_diff.unsqueeze(0)
622
+ differences.append(final_diff)
623
+
624
+ # Concatenate the f differences along the batch dimension (f-1 consecutive diffs plus the last-to-first wrap-around)
625
+ concatenated_differences = torch.cat(differences, dim=0)
626
+ # print('concatenated_differences shape', concatenated_differences.shape) # f 77 768
627
+
628
+ frame = concatenated_differences.size(0)
629
+
630
+ concatenated_differences = torch.cat(differences, dim=0)
631
+
632
+ # Current shape: (f, 77, 768), Pad the second dimension (77) to 128
633
+ pad_length = 128 - concatenated_differences.size(1)
634
+ if pad_length > 0:
635
+ # Pad along the second dimension (77 -> 128), pad only on the right side
636
+ concatenated_differences_padded = F.pad(concatenated_differences, (0, 0, 0, pad_length))
637
+
638
+ ccl_embedding = concatenated_differences_padded.reshape(frame, self.sample_size[0], self.sample_size[1])
639
+ ccl_embedding = ccl_embedding.unsqueeze(1)
640
+ ccl_embedding = ccl_embedding.expand(-1, 3, -1, -1)
641
+ # print('ccl_embedding shape', ccl_embedding.shape)
642
+
643
+ # Now handle the sensor image simulation
644
+ pixel_values = []
645
+ for aw in color_temperature_values:
646
+ img_sim = interpolate_white_balance(image_reader, aw)
647
+ pixel_values.append(img_sim)
648
+ pixel_values = np.stack(pixel_values, axis=0)
649
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous() / 255.
650
+
651
+ # Create color_temperature embedding and concatenate it with CCL embedding
652
+ color_temperature_embedding = create_color_temperature_embedding(color_temperature_values, self.sample_size[0], self.sample_size[1])
653
+ # print('color_temperature_embedding shape', color_temperature_embedding.shape)
654
+
655
+ camera_embedding = torch.cat((color_temperature_embedding, ccl_embedding), dim=1)
656
+ # print('camera_embedding shape', camera_embedding.shape)
657
+
658
+ return pixel_values, image_caption, camera_embedding, color_temperature_values
659
+
660
+ def __len__(self):
661
+ return self.length
662
+
663
+ def __getitem__(self, idx):
664
+ while True:
665
+ try:
666
+ video, video_caption, camera_embedding, color_temperature_values = self.get_batch(idx)
667
+ break
668
+ except Exception as e:
669
+ idx = random.randint(0, self.length - 1)
670
+
671
+ for transform in self.pixel_transforms:
672
+ video = transform(video)
673
+
674
+ sample = dict(pixel_values=video, text=video_caption, camera_embedding=camera_embedding, color_temperature_values=color_temperature_values)
675
+
676
+ return sample
677
+
678
+
679
+
680
+
681
+
682
+
683
+
684
+
685
+ #### for bokeh (K is the blur parameter) ####
686
+ def create_bokehK_embedding(bokehK_values, target_height, target_width):
687
+ """
688
+ Creates a Bokeh embedding based on the given K values. The larger the K value,
689
+ the more the image is blurred.
690
+
691
+ Args:
692
+ bokehK_values (torch.Tensor): Tensor of K values for bokeh effect.
693
+ target_height (int): Desired height of the output embedding.
694
+ target_width (int): Desired width of the output embedding.
695
696
+
697
+ Returns:
698
+ torch.Tensor: Bokeh embedding tensor. [f 3 h w]
699
+ """
700
+ f = bokehK_values.shape[0]
701
+ bokehK_embedding = torch.zeros((f, 3, target_height, target_width), dtype=bokehK_values.dtype)
702
+
703
+ for i in range(f):
704
+ K_value = bokehK_values[i].item()
705
+
706
+ kernel_size = max(K_value, 1)
707
+ sigma = K_value / 3.0
708
+
709
+ ax = np.linspace(-(kernel_size / 2), kernel_size / 2, int(np.ceil(kernel_size)))
710
+ xx, yy = np.meshgrid(ax, ax)
711
+ kernel = np.exp(-(xx ** 2 + yy ** 2) / (2 * sigma ** 2))
712
+ kernel /= np.sum(kernel)
713
+
714
+ scale = kernel[int(np.ceil(kernel_size) / 2), int(np.ceil(kernel_size) / 2)]
715
+ bokehK_embedding[i] = scale
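+ # The embedding for frame i is a constant map holding the center weight of a
+ # normalized Gaussian whose sigma grows with K, so larger K values yield smaller
+ # constants, i.e. a monotone encoding of the blur strength.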
716
+
717
+ return bokehK_embedding
718
+
719
+
720
+ def bokehK_simulation(image_path, depth_map_path, K, disp_focus, gamma=2.2):
721
+ ## depth map image can be inferenced online using following code ##
722
+ # model_dir = "/home/modules/"
723
+ # pipe = pipeline(
724
+ # task="depth-estimation",
725
+ # model="depth-anything/Depth-Anything-V2-Small-hf",
726
+ # cache_dir=model_dir,
727
+ # device=0
728
+ # )
729
+
730
+ # image_raw = Image.open(image_path)
731
+
732
+ # disp = pipe(image_raw)["depth"]
733
+ # base_name = os.path.basename(image_path)
734
+ # file_name, ext = os.path.splitext(base_name)
735
+
736
+ # disp_file_name = f"{file_name}_disp.png"
737
+ # disp.save(disp_file_name)
738
+
739
+ # disp = np.array(disp)
740
+ # disp = disp.astype(np.float32)
741
+ # disp /= 255.0
742
+
743
+ disp = np.float32(cv2.imread(depth_map_path, cv2.IMREAD_GRAYSCALE))
744
+
745
+ disp /= 255.0
746
+ disp = (disp - disp.min()) / (disp.max() - disp.min())
747
+ min_disp = np.min(disp)
748
+ max_disp = np.max(disp)
749
+
750
+ device = torch.device('cuda')
751
+
752
+ # Initialize renderer
753
+ classical_renderer = ModuleRenderScatter().to(device)
754
+
755
+ # Load image and disparity
756
+ image = cv2.imread(image_path).astype(np.float32) / 255.0
757
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
758
+
759
+ # Calculate defocus
760
+ defocus = K * (disp - disp_focus) / 10.0
761
+
762
+ # Convert to tensors and move to GPU if available
763
+ image = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).to(device)
764
+
765
+ defocus = defocus.unsqueeze(0).unsqueeze(0).to(device)
766
+
767
+ bokeh_classical, defocus_dilate = classical_renderer(image**gamma, defocus*10.0)
768
+ bokeh_pred = bokeh_classical ** (1/gamma)
769
+ bokeh_pred = bokeh_pred.squeeze(0)
770
+ bokeh_pred = bokeh_pred.permute(1, 2, 0) # remove batch dim and change channel order
771
+ bokeh_pred = (bokeh_pred * 255).cpu().numpy()
772
+ bokeh_pred = np.round(bokeh_pred)
773
+ bokeh_pred = bokeh_pred.astype(np.float32)
774
+
775
+ return bokeh_pred
776
+
777
+
778
+
779
+
780
+ class CameraBokehK(Dataset):
781
+ def __init__(
782
+ self,
783
+ root_path,
784
+ annotation_json,
785
+ sample_n_frames=5,
786
+ sample_size=[256, 384],
787
+ is_Train=True,
788
+ ):
789
+ self.root_path = root_path
790
+ self.sample_n_frames = sample_n_frames
791
+ self.dataset = json.load(open(os.path.join(root_path, annotation_json), 'r'))
792
+
793
+ self.length = len(self.dataset)
794
+ self.is_Train = is_Train
795
+ sample_size = tuple(sample_size) if not isinstance(sample_size, int) else (sample_size, sample_size)
796
+ self.sample_size = sample_size
797
+
798
+ pixel_transforms = [transforms.Resize(sample_size),
799
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)]
800
+
801
+ self.pixel_transforms = pixel_transforms
802
+ self.tokenizer = CLIPTokenizer.from_pretrained("/home/yuan418/data/project/stable-diffusion-v1-5/", subfolder="tokenizer")
803
+ self.text_encoder = CLIPTextModel.from_pretrained("/home/yuan418/data/project/stable-diffusion-v1-5/", subfolder="text_encoder")
804
+
805
+ def load_image_reader(self, idx):
806
+ image_dict = self.dataset[idx]
807
+
808
+ image_path = os.path.join(self.root_path, image_dict['base_image_path'])
809
+ depth_map_path = os.path.join(self.root_path, image_dict['depth_map_path'])
810
+
811
+ image_caption = image_dict['caption']
812
+
813
+
814
+ if self.is_Train:
815
+ bokehK_values = [random.uniform(1.0, 30.0) for _ in range(self.sample_n_frames)]
816
+ print('train bokehK values', bokehK_values)
817
+
818
+ else:
819
+ bokehK_list_str = image_dict['bokehK_list']
820
+ bokehK_values = json.loads(bokehK_list_str)
821
+ print('validation bokehK_values', bokehK_values)
822
+
823
+ bokehK_values = torch.tensor(bokehK_values).unsqueeze(1)
824
+ return image_path, depth_map_path, image_caption, bokehK_values
825
+
826
+
827
+ def get_batch(self, idx):
828
+ image_path, depth_map_path, image_caption, bokehK_values = self.load_image_reader(idx)
829
+
830
+ total_frames = len(bokehK_values)
831
+ if total_frames < 3:
832
+ raise ValueError("less than 3 frames")
833
+
834
+ # Generate prompts for each bokehK value and append bokehK information to caption
835
+ prompts = []
836
+ for bb in bokehK_values:
837
+ prompt = f"<bokeh kernel size: {bb.item()}>"
838
+ prompts.append(prompt)
839
+
840
+ # Tokenize prompts and encode to get embeddings
841
+ with torch.no_grad():
842
+ prompt_ids = self.tokenizer(
843
+ prompts, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
844
+ ).input_ids
845
+ # print('tokenizer model_max_length', self.tokenizer.model_max_length)
846
+
847
+ encoder_hidden_states = self.text_encoder(input_ids=prompt_ids).last_hidden_state # Shape: (f, sequence_length, hidden_size)
848
+
849
+ # print('encoder_hidden_states shape', encoder_hidden_states.shape)
850
+
851
+ # Calculate differences between consecutive embeddings (ignoring sequence_length)
852
+ differences = []
853
+ for i in range(1, encoder_hidden_states.size(0)):
854
+ diff = encoder_hidden_states[i] - encoder_hidden_states[i - 1]
855
+ diff = diff.unsqueeze(0)
856
+ differences.append(diff)
857
+
858
+ # Add the difference between the last and the first embedding
859
+ final_diff = encoder_hidden_states[-1] - encoder_hidden_states[0]
860
+ final_diff = final_diff.unsqueeze(0)
861
+ differences.append(final_diff)
862
+
863
+ # Concatenate differences along the batch dimension (f-1)
864
+ concatenated_differences = torch.cat(differences, dim=0)
865
+
866
+ # print('concatenated_differences shape', concatenated_differences.shape) # f 77 768
867
+
868
+ frame = concatenated_differences.size(0)
869
+
870
+ # Concatenate differences along the batch dimension (f)
871
+ concatenated_differences = torch.cat(differences, dim=0)
872
+
873
+ # Current shape: (f, 77, 768), Pad the second dimension (77) to 128
874
+ pad_length = 128 - concatenated_differences.size(1)
875
+ if pad_length > 0:
876
+ # Pad along the second dimension (77 -> 128), pad only on the right side
877
+ concatenated_differences = F.pad(concatenated_differences, (0, 0, 0, pad_length))
878
+
879
+ ## ccl = contrastive camera learning ##
880
+ ccl_embedding = concatenated_differences.reshape(frame, self.sample_size[0], self.sample_size[1])
881
+ ccl_embedding = ccl_embedding.unsqueeze(1)
882
+ ccl_embedding = ccl_embedding.expand(-1, 3, -1, -1)
883
+ # print('ccl_embedding shape', ccl_embedding.shape)
884
+
885
+ pixel_values = []
886
+ for bk in bokehK_values:
887
+ img_sim = bokehK_simulation(image_path, depth_map_path, bk, disp_focus=0.96, gamma=2.2)
888
+ # save_path = os.path.join(self.root_path, f"simulated_img_bokeh_{bk.item():.2f}.png")
889
+ # cv2.imwrite(save_path, img_sim)
890
+ # print(f"Saved image: {save_path}")
891
+ pixel_values.append(img_sim)
892
+
893
+ pixel_values = np.stack(pixel_values, axis=0)
894
+ pixel_values = torch.from_numpy(pixel_values).permute(0, 3, 1, 2).contiguous() / 255.
895
+
896
+ # Create bokehK embedding and concatenate it with CCL embedding
897
+ bokehK_embedding = create_bokehK_embedding(bokehK_values, self.sample_size[0], self.sample_size[1])
898
+
899
+ camera_embedding = torch.cat((bokehK_embedding, ccl_embedding), dim=1)
900
+ # print('camera_embedding shape', camera_embedding.shape)
901
+
902
+ return pixel_values, image_caption, camera_embedding, bokehK_values
903
+
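+ # Note: `camera_embedding` stacks the bokeh-strength embedding and the CCL (prompt-difference)
+ # embedding along the channel dimension, yielding one (C, H, W) conditioning map per frame.
+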
904
+ def __len__(self):
905
+ return self.length
906
+
907
+ def __getitem__(self, idx):
908
+ while True:
909
+ try:
910
+ video, video_caption, camera_embedding, bokehK_values = self.get_batch(idx)
911
+ break
912
+ except Exception as e:
913
+ idx = random.randint(0, self.length - 1)
914
+
915
+ for transform in self.pixel_transforms:
916
+ video = transform(video)
917
+
918
+ sample = dict(pixel_values=video, text=video_caption, camera_embedding=camera_embedding, bokehK_values=bokehK_values)
919
+
920
+ return sample
921
+
922
+
923
+
924
+ def test_camera_bokehK_dataset():
925
+ root_path = '/home/yuan418/data/project/camera_dataset/camera_bokehK/'
926
+ annotation_json = 'annotations/inference.json'
927
+
928
+ print('------------------')
929
+ dataset = CameraBokehK(
930
+ root_path=root_path,
931
+ annotation_json=annotation_json,
932
+ sample_n_frames=4,
933
+ sample_size=[256, 384],
934
+ is_Train=False,
935
+ )
936
+
937
+ # choose one sample for testing
938
+ idx = 1
939
+ sample = dataset[idx]
940
+
941
+ pixel_values = sample['pixel_values']
942
+ text = sample['text']
943
+ camera_embedding = sample['camera_embedding']
944
+ print(f"Pixel values shape: {pixel_values.shape}")
945
+ print(f"Text: {text}")
946
+ print(f"camera embedding shape: {camera_embedding.shape}")
947
+
948
+
949
+ if __name__ == "__main__":
950
+ test_camera_bokehK_dataset()
genphoto/models/attention.py CHANGED
@@ -1,3 +1,136 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:82630247828d56f38b979a4a7b9bc12290ada3a1ce5be1d6153d07dbe4baaaa0
3
- size 5313
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
2
+
3
+ from dataclasses import dataclass
4
+ from typing import Optional
5
+
6
+ import torch
7
+ from torch import nn
8
+
9
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
10
+ from diffusers.models.modeling_utils import ModelMixin
11
+ from diffusers.utils import BaseOutput
12
+ from diffusers.models.attention import BasicTransformerBlock
13
+ from einops import rearrange, repeat
14
+
15
+
16
+ @dataclass
17
+ class Transformer3DModelOutput(BaseOutput):
18
+ sample: torch.FloatTensor
19
+
20
+
21
+ class Transformer3DModel(ModelMixin, ConfigMixin):
22
+ @register_to_config
23
+ def __init__(
24
+ self,
25
+ num_attention_heads: int = 16,
26
+ attention_head_dim: int = 88,
27
+ in_channels: Optional[int] = None,
28
+ num_layers: int = 1,
29
+ dropout: float = 0.0,
30
+ norm_num_groups: int = 32,
31
+ cross_attention_dim: Optional[int] = None,
32
+ attention_bias: bool = False,
33
+ activation_fn: str = "geglu",
34
+ num_embeds_ada_norm: Optional[int] = None,
35
+ use_linear_projection: bool = False,
36
+ only_cross_attention: bool = False,
37
+ upcast_attention: bool = False,
38
+ norm_type: str = "layer_norm",
39
+ norm_elementwise_affine: bool = True,
40
+ ):
41
+ super().__init__()
42
+ self.use_linear_projection = use_linear_projection
43
+ self.num_attention_heads = num_attention_heads
44
+ self.attention_head_dim = attention_head_dim
45
+ inner_dim = num_attention_heads * attention_head_dim
46
+
47
+ # Define input layers
48
+ self.in_channels = in_channels
49
+
50
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
51
+ if use_linear_projection:
52
+ self.proj_in = nn.Linear(in_channels, inner_dim)
53
+ else:
54
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
55
+
56
+ # Define transformers blocks
57
+ self.transformer_blocks = nn.ModuleList(
58
+ [
59
+ BasicTransformerBlock(
60
+ inner_dim,
61
+ num_attention_heads,
62
+ attention_head_dim,
63
+ dropout=dropout,
64
+ cross_attention_dim=cross_attention_dim,
65
+ activation_fn=activation_fn,
66
+ num_embeds_ada_norm=num_embeds_ada_norm,
67
+ attention_bias=attention_bias,
68
+ only_cross_attention=only_cross_attention,
69
+ upcast_attention=upcast_attention,
70
+ norm_type=norm_type,
71
+ norm_elementwise_affine=norm_elementwise_affine,
72
+ )
73
+ for d in range(num_layers)
74
+ ]
75
+ )
76
+
77
+ # 4. Define output layers
78
+ if use_linear_projection:
79
+ self.proj_out = nn.Linear(in_channels, inner_dim)
80
+ else:
81
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
82
+
83
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
84
+ # Input
85
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
86
+ batch_size, _, video_length = hidden_states.shape[:3]
87
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
88
+
89
+ if encoder_hidden_states.shape[0] == batch_size:
90
+ encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
91
+
92
+ elif encoder_hidden_states.shape[0] == batch_size * video_length:
93
+ pass
94
+ else:
95
+ raise ValueError
96
+
97
+ batch, channel, height, weight = hidden_states.shape
98
+ residual = hidden_states
99
+
100
+ hidden_states = self.norm(hidden_states)
101
+ if not self.use_linear_projection:
102
+ hidden_states = self.proj_in(hidden_states)
103
+ inner_dim = hidden_states.shape[1]
104
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
105
+ else:
106
+ inner_dim = hidden_states.shape[1]
107
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
108
+ hidden_states = self.proj_in(hidden_states)
109
+
110
+ # Blocks
111
+ for block in self.transformer_blocks:
112
+ hidden_states = block(
113
+ hidden_states,
114
+ encoder_hidden_states=encoder_hidden_states,
115
+ timestep=timestep,
116
+ )
117
+
118
+ # Output
119
+ if not self.use_linear_projection:
120
+ hidden_states = (
121
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
122
+ )
123
+ hidden_states = self.proj_out(hidden_states)
124
+ else:
125
+ hidden_states = self.proj_out(hidden_states)
126
+ hidden_states = (
127
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
128
+ )
129
+
130
+ output = hidden_states + residual
131
+
132
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
133
+ if not return_dict:
134
+ return (output,)
135
+
136
+ return Transformer3DModelOutput(sample=output)
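+
+ # Shape sketch for the module above (illustrative values, assuming use_linear_projection=False):
+ # model = Transformer3DModel(num_attention_heads=8, attention_head_dim=40, in_channels=320, cross_attention_dim=768)
+ # x = torch.randn(2, 320, 4, 32, 48)    # (b, c, f, h, w)
+ # ctx = torch.randn(2, 77, 768)          # one text embedding per video, repeated per frame internally
+ # out = model(x, encoder_hidden_states=ctx).sample   # out.shape == x.shape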
genphoto/models/attention_processor.py CHANGED
@@ -1,3 +1,412 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:1fc36c35808aed64eb238e3dba643b51961992388dd76d945dec36760ab87557
3
- size 16681
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.init as init
4
+ import logging
5
+ from diffusers.models.lora import LoRALinearLayer
6
+ from diffusers.models.attention import Attention
7
+ from diffusers.utils import USE_PEFT_BACKEND
8
+ from typing import Optional
9
+
10
+ from einops import rearrange
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ class AttnProcessor:
16
+ r"""
17
+ Default processor for performing attention-related computations.
18
+ """
19
+
20
+ def __call__(
21
+ self,
22
+ attn: Attention,
23
+ hidden_states: torch.FloatTensor,
24
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
25
+ attention_mask: Optional[torch.FloatTensor] = None,
26
+ temb: Optional[torch.FloatTensor] = None,
27
+ scale: float = 1.0,
28
+ camera_feature=None
29
+ ) -> torch.Tensor:
30
+ residual = hidden_states
31
+
32
+ args = () if USE_PEFT_BACKEND else (scale,)
33
+
34
+ if attn.spatial_norm is not None:
35
+ hidden_states = attn.spatial_norm(hidden_states, temb)
36
+
37
+ input_ndim = hidden_states.ndim
38
+
39
+ if input_ndim == 4:
40
+ batch_size, channel, height, width = hidden_states.shape
41
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
42
+
43
+ batch_size, sequence_length, _ = (
44
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
45
+ )
46
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
47
+
48
+ if attn.group_norm is not None:
49
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
50
+
51
+ query = attn.to_q(hidden_states, *args)
52
+
53
+ if encoder_hidden_states is None:
54
+ encoder_hidden_states = hidden_states
55
+ elif attn.norm_cross:
56
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
57
+
58
+ key = attn.to_k(encoder_hidden_states, *args)
59
+ value = attn.to_v(encoder_hidden_states, *args)
60
+
61
+ query = attn.head_to_batch_dim(query)
62
+ key = attn.head_to_batch_dim(key)
63
+ value = attn.head_to_batch_dim(value)
64
+
65
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
66
+ hidden_states = torch.bmm(attention_probs, value)
67
+ hidden_states = attn.batch_to_head_dim(hidden_states)
68
+
69
+ # linear proj
70
+ hidden_states = attn.to_out[0](hidden_states, *args)
71
+ # dropout
72
+ hidden_states = attn.to_out[1](hidden_states)
73
+
74
+ if input_ndim == 4:
75
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
76
+
77
+ if attn.residual_connection:
78
+ hidden_states = hidden_states + residual
79
+
80
+ hidden_states = hidden_states / attn.rescale_output_factor
81
+
82
+ return hidden_states
83
+
84
+
85
+ class LoRAAttnProcessor(nn.Module):
86
+ r"""
87
+ Default processor for performing attention-related computations.
88
+ """
89
+
90
+ def __init__(
91
+ self,
92
+ hidden_size=None,
93
+ cross_attention_dim=None,
94
+ rank=4,
95
+ network_alpha=None,
96
+ lora_scale=1.0,
97
+ ):
98
+ super().__init__()
99
+
100
+ self.rank = rank
101
+ self.lora_scale = lora_scale
102
+
103
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
104
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
105
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
106
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
107
+
108
+ def __call__(
109
+ self,
110
+ attn,
111
+ hidden_states,
112
+ encoder_hidden_states=None,
113
+ attention_mask=None,
114
+ temb=None,
115
+ camera_feature=None,
116
+ scale=None
117
+ ):
118
+ lora_scale = self.lora_scale if scale is None else scale
119
+ residual = hidden_states
120
+
121
+ if attn.spatial_norm is not None:
122
+ hidden_states = attn.spatial_norm(hidden_states, temb)
123
+
124
+ input_ndim = hidden_states.ndim
125
+
126
+ if input_ndim == 4:
127
+ batch_size, channel, height, width = hidden_states.shape
128
+ hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
129
+
130
+ batch_size, sequence_length, _ = (
131
+ hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
132
+ )
133
+ attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
134
+
135
+ if attn.group_norm is not None:
136
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
137
+
138
+ query = attn.to_q(hidden_states) + lora_scale * self.to_q_lora(hidden_states)
139
+
140
+ if encoder_hidden_states is None:
141
+ encoder_hidden_states = hidden_states
142
+ elif attn.norm_cross:
143
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
144
+
145
+ key = attn.to_k(encoder_hidden_states) + lora_scale * self.to_k_lora(encoder_hidden_states)
146
+ value = attn.to_v(encoder_hidden_states) + lora_scale * self.to_v_lora(encoder_hidden_states)
147
+
148
+ query = attn.head_to_batch_dim(query)
149
+ key = attn.head_to_batch_dim(key)
150
+ value = attn.head_to_batch_dim(value)
151
+
152
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
153
+ hidden_states = torch.bmm(attention_probs, value)
154
+ hidden_states = attn.batch_to_head_dim(hidden_states)
155
+
156
+ # linear proj
157
+ hidden_states = attn.to_out[0](hidden_states) + lora_scale * self.to_out_lora(hidden_states)
158
+ # dropout
159
+ hidden_states = attn.to_out[1](hidden_states)
160
+
161
+ if input_ndim == 4:
162
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
163
+
164
+ if attn.residual_connection:
165
+ hidden_states = hidden_states + residual
166
+
167
+ hidden_states = hidden_states / attn.rescale_output_factor
168
+
169
+ return hidden_states
170
+
171
+
172
+ class CameraAdaptorAttnProcessor(nn.Module):
173
+ def __init__(self,
174
+ hidden_size, # dimension of hidden state
175
+ camera_feature_dim=None, # dimension of the camera feature
176
+ cross_attention_dim=None, # dimension of the text embedding
177
+ query_condition=False,
178
+ key_value_condition=False,
179
+ scale=1.0):
180
+ super().__init__()
181
+
182
+ self.hidden_size = hidden_size
183
+ self.camera_feature_dim = camera_feature_dim
184
+ self.cross_attention_dim = cross_attention_dim
185
+ self.scale = scale
186
+ self.query_condition = query_condition
187
+ self.key_value_condition = key_value_condition
188
+ assert hidden_size == camera_feature_dim
189
+ if self.query_condition and self.key_value_condition:
190
+ self.qkv_merge = nn.Linear(hidden_size, hidden_size)
191
+ init.zeros_(self.qkv_merge.weight)
192
+ init.zeros_(self.qkv_merge.bias)
193
+ elif self.query_condition:
194
+ self.q_merge = nn.Linear(hidden_size, hidden_size)
195
+ init.zeros_(self.q_merge.weight)
196
+ init.zeros_(self.q_merge.bias)
197
+ else:
198
+ self.kv_merge = nn.Linear(hidden_size, hidden_size)
199
+ init.zeros_(self.kv_merge.weight)
200
+ init.zeros_(self.kv_merge.bias)
201
+
202
+ def forward(self,
203
+ attn,
204
+ hidden_states,
205
+ camera_feature,
206
+ encoder_hidden_states=None,
207
+ attention_mask=None,
208
+ temb=None,
209
+ scale=None,):
210
+ assert camera_feature is not None
211
+ camera_embedding_scale = (scale or self.scale)
212
+
213
+ residual = hidden_states
214
+ if attn.spatial_norm is not None:
215
+ hidden_states = attn.spatial_norm(hidden_states, temb)
216
+
217
+ if hidden_states.ndim == 5:
218
+ hidden_states = rearrange(hidden_states, 'b c f h w -> (b f) (h w) c')
219
+ elif hidden_states.ndim == 4:
220
+ hidden_states = rearrange(hidden_states, 'b c h w -> b (h w) c')
221
+ else:
222
+ assert hidden_states.ndim == 3
223
+
224
+ if self.query_condition and self.key_value_condition:
225
+ assert encoder_hidden_states is None
226
+
227
+ if encoder_hidden_states is None:
228
+ encoder_hidden_states = hidden_states
229
+
230
+ if encoder_hidden_states.ndim == 5:
231
+ encoder_hidden_states = rearrange(encoder_hidden_states, 'b c f h w -> (b f) (h w) c')
232
+ elif encoder_hidden_states.ndim == 4:
233
+ encoder_hidden_states = rearrange(encoder_hidden_states, 'b c h w -> b (h w) c')
234
+ else:
235
+ assert encoder_hidden_states.ndim == 3
236
+ if camera_feature.ndim == 5:
237
+ camera_feature = rearrange(camera_feature, "b c f h w -> (b f) (h w) c")
238
+ elif camera_feature.ndim == 4:
239
+ camera_feature = rearrange(camera_feature, "b c h w -> b (h w) c")
240
+ else:
241
+ assert camera_feature.ndim == 3
242
+
243
+ batch_size, ehs_sequence_length, _ = encoder_hidden_states.shape
244
+ attention_mask = attn.prepare_attention_mask(attention_mask, ehs_sequence_length, batch_size)
245
+
246
+ if attn.group_norm is not None:
247
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
248
+
249
+ if attn.norm_cross:
250
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
251
+
252
+ if self.query_condition and self.key_value_condition: # only self attention
253
+ query_hidden_state = self.qkv_merge(hidden_states + camera_feature) * camera_embedding_scale + hidden_states
254
+ key_value_hidden_state = query_hidden_state
255
+ elif self.query_condition:
256
+ query_hidden_state = self.q_merge(hidden_states + camera_feature) * camera_embedding_scale + hidden_states
257
+ key_value_hidden_state = encoder_hidden_states
258
+ else:
259
+ key_value_hidden_state = self.kv_merge(encoder_hidden_states + camera_feature) * camera_embedding_scale + encoder_hidden_states
260
+ query_hidden_state = hidden_states
261
+
262
+ # original attention
263
+ query = attn.to_q(query_hidden_state)
264
+ key = attn.to_k(key_value_hidden_state)
265
+ value = attn.to_v(key_value_hidden_state)
266
+
267
+ query = attn.head_to_batch_dim(query)
268
+ key = attn.head_to_batch_dim(key)
269
+ value = attn.head_to_batch_dim(value)
270
+
271
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
272
+ hidden_states = torch.bmm(attention_probs, value)
273
+ hidden_states = attn.batch_to_head_dim(hidden_states)
274
+
275
+ # linear proj
276
+ hidden_states = attn.to_out[0](hidden_states)
277
+ # dropout
278
+ hidden_states = attn.to_out[1](hidden_states)
279
+
280
+ if attn.residual_connection:
281
+ hidden_states = hidden_states + residual
282
+
283
+ hidden_states = hidden_states / attn.rescale_output_factor
284
+
285
+ return hidden_states
286
+
287
+
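+ # Note on the merge layers above: because qkv_merge / q_merge / kv_merge are zero-initialized,
+ # the camera branch contributes nothing at the start of training and the processor behaves like
+ # plain attention; the camera conditioning is learned gradually from that identity starting point.
+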
288
+ class LORACameraAdaptorAttnProcessor(nn.Module):
289
+ def __init__(self,
290
+ hidden_size, # dimension of hidden state
291
+ camera_feature_dim=None, # dimension of the camera feature
292
+ cross_attention_dim=None, # dimension of the text embedding
293
+ query_condition=False,
294
+ key_value_condition=False,
295
+ scale=1.0,
296
+ # lora keywords
297
+ rank=4,
298
+ network_alpha=None,
299
+ lora_scale=1.0):
300
+ super().__init__()
301
+
302
+ self.hidden_size = hidden_size
303
+ self.camera_feature_dim = camera_feature_dim
304
+ self.cross_attention_dim = cross_attention_dim
305
+ self.scale = scale
306
+ self.query_condition = query_condition
307
+ self.key_value_condition = key_value_condition
308
+ assert hidden_size == camera_feature_dim
309
+ if self.query_condition and self.key_value_condition:
310
+ self.qkv_merge = nn.Linear(hidden_size, hidden_size)
311
+ init.zeros_(self.qkv_merge.weight)
312
+ init.zeros_(self.qkv_merge.bias)
313
+ elif self.query_condition:
314
+ self.q_merge = nn.Linear(hidden_size, hidden_size)
315
+ init.zeros_(self.q_merge.weight)
316
+ init.zeros_(self.q_merge.bias)
317
+ else:
318
+ self.kv_merge = nn.Linear(hidden_size, hidden_size)
319
+ init.zeros_(self.kv_merge.weight)
320
+ init.zeros_(self.kv_merge.bias)
321
+ # lora
322
+ self.rank = rank
323
+ self.lora_scale = lora_scale
324
+ self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
325
+ self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
326
+ self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size, rank, network_alpha)
327
+ self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank, network_alpha)
328
+
329
+ def __call__(self,
330
+ attn,
331
+ hidden_states,
332
+ encoder_hidden_states=None,
333
+ attention_mask=None,
334
+ temb=None,
335
+ scale=1.0,
336
+ camera_feature=None,
337
+ ):
338
+ assert camera_feature is not None
339
+ lora_scale = self.lora_scale if scale is None else scale
340
+ residual = hidden_states
341
+ if attn.spatial_norm is not None:
342
+ hidden_states = attn.spatial_norm(hidden_states, temb)
343
+
344
+ if hidden_states.ndim == 5:
345
+ hidden_states = rearrange(hidden_states, 'b c f h w -> (b f) (h w) c')
346
+ elif hidden_states.ndim == 4:
347
+ hidden_states = rearrange(hidden_states, 'b c h w -> b (h w) c')
348
+ else:
349
+ assert hidden_states.ndim == 3
350
+
351
+ if self.query_condition and self.key_value_condition:
352
+ assert encoder_hidden_states is None
353
+
354
+ if encoder_hidden_states is None:
355
+ encoder_hidden_states = hidden_states
356
+
357
+ if encoder_hidden_states.ndim == 5:
358
+ encoder_hidden_states = rearrange(encoder_hidden_states, 'b c f h w -> (b f) (h w) c')
359
+ elif encoder_hidden_states.ndim == 4:
360
+ encoder_hidden_states = rearrange(encoder_hidden_states, 'b c h w -> b (h w) c')
361
+ else:
362
+ assert encoder_hidden_states.ndim == 3
363
+ if camera_feature.ndim == 5:
364
+ camera_feature = rearrange(camera_feature, "b c f h w -> (b f) (h w) c")
365
+ elif camera_feature.ndim == 4:
366
+ camera_feature = rearrange(camera_feature, "b c h w -> b (h w) c")
367
+ else:
368
+ assert camera_feature.ndim == 3
369
+
370
+ batch_size, ehs_sequence_length, _ = encoder_hidden_states.shape
371
+ attention_mask = attn.prepare_attention_mask(attention_mask, ehs_sequence_length, batch_size)
372
+
373
+ if attn.group_norm is not None:
374
+ hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
375
+
376
+ if attn.norm_cross:
377
+ encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
378
+
379
+ if self.query_condition and self.key_value_condition: # only self attention
380
+ query_hidden_state = self.qkv_merge(hidden_states + camera_feature) * self.scale + hidden_states
381
+ key_value_hidden_state = query_hidden_state
382
+ elif self.query_condition:
383
+ query_hidden_state = self.q_merge(hidden_states + camera_feature) * self.scale + hidden_states
384
+ key_value_hidden_state = encoder_hidden_states
385
+ else:
386
+ key_value_hidden_state = self.kv_merge(encoder_hidden_states + camera_feature) * self.scale + encoder_hidden_states
387
+ query_hidden_state = hidden_states
388
+
389
+ # original attention
390
+ query = attn.to_q(query_hidden_state) + lora_scale * self.to_q_lora(query_hidden_state)
391
+ key = attn.to_k(key_value_hidden_state) + lora_scale * self.to_k_lora(key_value_hidden_state)
392
+ value = attn.to_v(key_value_hidden_state) + lora_scale * self.to_v_lora(key_value_hidden_state)
393
+
394
+ query = attn.head_to_batch_dim(query)
395
+ key = attn.head_to_batch_dim(key)
396
+ value = attn.head_to_batch_dim(value)
397
+
398
+ attention_probs = attn.get_attention_scores(query, key, attention_mask)
399
+ hidden_states = torch.bmm(attention_probs, value)
400
+ hidden_states = attn.batch_to_head_dim(hidden_states)
401
+
402
+ # linear proj
403
+ hidden_states = attn.to_out[0](hidden_states) + lora_scale * self.to_out_lora(hidden_states)
404
+ # dropout
405
+ hidden_states = attn.to_out[1](hidden_states)
406
+
407
+ if attn.residual_connection:
408
+ hidden_states = hidden_states + residual
409
+
410
+ hidden_states = hidden_states / attn.rescale_output_factor
411
+
412
+ return hidden_states
genphoto/models/camera_adaptor.py CHANGED
@@ -1,3 +1,246 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:b98af7dc452f718e7b74536412d017231a15d69933a224cd1cb9557fe5853ba5
3
- size 9775
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ from einops import rearrange
5
+ from genphoto.models.motion_module import TemporalTransformerBlock
6
+
7
+
8
+ def get_parameter_dtype(parameter: torch.nn.Module):
9
+ try:
10
+ params = tuple(parameter.parameters())
11
+ if len(params) > 0:
12
+ return params[0].dtype
13
+
14
+ buffers = tuple(parameter.buffers())
15
+ if len(buffers) > 0:
16
+ return buffers[0].dtype
17
+
18
+ except StopIteration:
19
+ # For torch.nn.DataParallel compatibility in PyTorch 1.5
20
+
21
+ def find_tensor_attributes(module: torch.nn.Module) -> "list[tuple[str, torch.Tensor]]":
22
+ tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
23
+ return tuples
24
+
25
+ gen = parameter._named_members(get_members_fn=find_tensor_attributes)
26
+ first_tuple = next(gen)
27
+ return first_tuple[1].dtype
28
+
29
+
30
+ def conv_nd(dims, *args, **kwargs):
31
+ """
32
+ Create a 1D, 2D, or 3D convolution module.
33
+ """
34
+ if dims == 1:
35
+ return nn.Conv1d(*args, **kwargs)
36
+ elif dims == 2:
37
+ return nn.Conv2d(*args, **kwargs)
38
+ elif dims == 3:
39
+ return nn.Conv3d(*args, **kwargs)
40
+ raise ValueError(f"unsupported dimensions: {dims}")
41
+
42
+
43
+ def avg_pool_nd(dims, *args, **kwargs):
44
+ """
45
+ Create a 1D, 2D, or 3D average pooling module.
46
+ """
47
+ if dims == 1:
48
+ return nn.AvgPool1d(*args, **kwargs)
49
+ elif dims == 2:
50
+ return nn.AvgPool2d(*args, **kwargs)
51
+ elif dims == 3:
52
+ return nn.AvgPool3d(*args, **kwargs)
53
+ raise ValueError(f"unsupported dimensions: {dims}")
54
+
55
+
56
+ class CameraAdaptor(nn.Module):
57
+ def __init__(self, unet, camera_encoder):
58
+ super().__init__()
59
+ self.unet = unet
60
+ self.camera_encoder = camera_encoder
61
+
62
+ def forward(self, noisy_latents, timesteps, encoder_hidden_states, camera_embedding):
63
+ assert camera_embedding.ndim == 5
64
+ bs = camera_embedding.shape[0] # b c f h w
65
+ camera_embedding_features = self.camera_encoder(camera_embedding) # bf c h w
66
+ camera_embedding_features = [rearrange(x, '(b f) c h w -> b c f h w', b=bs)
67
+ for x in camera_embedding_features]
68
+ noise_pred = self.unet(noisy_latents,
69
+ timesteps,
70
+ encoder_hidden_states,
71
+ camera_embedding_features=camera_embedding_features).sample
72
+ return noise_pred
73
+
74
+
75
+ class Downsample(nn.Module):
76
+ """
77
+ A downsampling layer with an optional convolution.
78
+ :param channels: channels in the inputs and outputs.
79
+ :param use_conv: a bool determining if a convolution is applied.
80
+ :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
81
+ downsampling occurs in the inner-two dimensions.
82
+ """
83
+
84
+ def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
85
+ super().__init__()
86
+ self.channels = channels
87
+ self.out_channels = out_channels or channels
88
+ self.use_conv = use_conv
89
+ self.dims = dims
90
+ stride = 2 if dims != 3 else (1, 2, 2)
91
+ if use_conv:
92
+ self.op = conv_nd(dims, self.channels, self.out_channels, 3, stride=stride, padding=padding)
93
+ else:
94
+ assert self.channels == self.out_channels
95
+ self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
96
+
97
+ def forward(self, x):
98
+ assert x.shape[1] == self.channels
99
+ return self.op(x)
100
+
101
+
102
+ class ResnetBlock(nn.Module):
103
+
104
+ def __init__(self, in_c, out_c, down, ksize=3, sk=False, use_conv=True):
105
+ super().__init__()
106
+ ps = ksize // 2
107
+ if in_c != out_c or sk == False:
108
+ self.in_conv = nn.Conv2d(in_c, out_c, ksize, 1, ps)
109
+ else:
110
+ self.in_conv = None
111
+ self.block1 = nn.Conv2d(out_c, out_c, 3, 1, 1)
112
+ self.act = nn.ReLU()
113
+ self.block2 = nn.Conv2d(out_c, out_c, ksize, 1, ps)
114
+ if sk == False:
115
+ self.skep = nn.Conv2d(in_c, out_c, ksize, 1, ps)
116
+ else:
117
+ self.skep = None
118
+
119
+ self.down = down
120
+ if self.down == True:
121
+ self.down_opt = Downsample(in_c, use_conv=use_conv)
122
+
123
+ def forward(self, x):
124
+ if self.down == True:
125
+ x = self.down_opt(x)
126
+ if self.in_conv is not None:
127
+ x = self.in_conv(x)
128
+
129
+ h = self.block1(x)
130
+ h = self.act(h)
131
+ h = self.block2(h)
132
+ if self.skep is not None:
133
+ return h + self.skep(x)
134
+ else:
135
+ return h + x
136
+
137
+
138
+ class PositionalEncoding(nn.Module):
139
+ def __init__(
140
+ self,
141
+ d_model,
142
+ dropout=0.,
143
+ max_len=32,
144
+ ):
145
+ super().__init__()
146
+ self.dropout = nn.Dropout(p=dropout)
147
+ position = torch.arange(max_len).unsqueeze(1)
148
+ div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
149
+ pe = torch.zeros(1, max_len, d_model)
150
+ pe[0, :, 0::2, ...] = torch.sin(position * div_term)
151
+ pe[0, :, 1::2, ...] = torch.cos(position * div_term)
152
+ pe.unsqueeze_(-1).unsqueeze_(-1)
153
+ self.register_buffer('pe', pe)
154
+
155
+ def forward(self, x):
156
+ x = x + self.pe[:, :x.size(1), ...]
157
+ return self.dropout(x)
158
+
159
+
160
+ class CameraCameraEncoder(nn.Module):
161
+
162
+ def __init__(self,
163
+ downscale_factor,
164
+ channels=[320, 640, 1280, 1280],
165
+ nums_rb=3,
166
+ cin=64,
167
+ ksize=3,
168
+ sk=False,
169
+ use_conv=True,
170
+ compression_factor=1,
171
+ temporal_attention_nhead=8,
172
+ attention_block_types=("Temporal_Self", ),
173
+ temporal_position_encoding=False,
174
+ temporal_position_encoding_max_len=8,
175
+ rescale_output_factor=1.0):
176
+ super(CameraCameraEncoder, self).__init__()
177
+ self.unshuffle = nn.PixelUnshuffle(downscale_factor)
178
+ self.channels = channels
179
+ self.nums_rb = nums_rb
180
+ self.encoder_down_conv_blocks = nn.ModuleList()
181
+ self.encoder_down_attention_blocks = nn.ModuleList()
182
+ for i in range(len(channels)):
183
+ conv_layers = nn.ModuleList()
184
+ temporal_attention_layers = nn.ModuleList()
185
+ for j in range(nums_rb):
186
+ if j == 0 and i != 0:
187
+ in_dim = channels[i - 1]
188
+ out_dim = int(channels[i] / compression_factor)
189
+ conv_layer = ResnetBlock(in_dim, out_dim, down=True, ksize=ksize, sk=sk, use_conv=use_conv)
190
+ elif j == 0:
191
+ in_dim = channels[0]
192
+ out_dim = int(channels[i] / compression_factor)
193
+ conv_layer = ResnetBlock(in_dim, out_dim, down=False, ksize=ksize, sk=sk, use_conv=use_conv)
194
+ elif j == nums_rb - 1:
195
+ in_dim = int(channels[i] / compression_factor)
196
+ out_dim = channels[i]
197
+ conv_layer = ResnetBlock(in_dim, out_dim, down=False, ksize=ksize, sk=sk, use_conv=use_conv)
198
+ else:
199
+ in_dim = int(channels[i] / compression_factor)
200
+ out_dim = int(channels[i] / compression_factor)
201
+ conv_layer = ResnetBlock(in_dim, out_dim, down=False, ksize=ksize, sk=sk, use_conv=use_conv)
202
+ temporal_attention_layer = TemporalTransformerBlock(dim=out_dim,
203
+ num_attention_heads=temporal_attention_nhead,
204
+ attention_head_dim=int(out_dim / temporal_attention_nhead),
205
+ attention_block_types=attention_block_types,
206
+ dropout=0.0,
207
+ cross_attention_dim=None,
208
+ temporal_position_encoding=temporal_position_encoding,
209
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
210
+ rescale_output_factor=rescale_output_factor)
211
+ conv_layers.append(conv_layer)
212
+ temporal_attention_layers.append(temporal_attention_layer)
213
+ self.encoder_down_conv_blocks.append(conv_layers)
214
+ self.encoder_down_attention_blocks.append(temporal_attention_layers)
215
+
216
+ self.encoder_conv_in = nn.Conv2d(cin, channels[0], 3, 1, 1)
217
+
218
+ @property
219
+ def dtype(self) -> torch.dtype:
220
+ """
221
+ `torch.dtype`: The dtype of the module (assuming that all the module parameters have the same dtype).
222
+ """
223
+ return get_parameter_dtype(self)
224
+
225
+ def forward(self, x):
226
+ # unshuffle
227
+ bs = x.shape[0]
228
+ x = rearrange(x, "b c f h w -> (b f) c h w")
229
+ x = self.unshuffle(x)
230
+ # extract features
231
+ features = []
232
+ x = self.encoder_conv_in(x)
233
+
234
+ # print('xxxx 1111 shape', x.shape)
235
+
236
+ for res_block, attention_block in zip(self.encoder_down_conv_blocks, self.encoder_down_attention_blocks):
237
+ for res_layer, attention_layer in zip(res_block, attention_block):
238
+ x = res_layer(x)
239
+ # print('xxxx 2222 shape', x.shape)
240
+ h, w = x.shape[-2:]
241
+ x = rearrange(x, '(b f) c h w -> (b h w) f c', b=bs)
242
+ x = attention_layer(x)
243
+ # print('xxxx 3333 shape', x.shape)
244
+ x = rearrange(x, '(b h w) f c -> (b f) c h w', h=h, w=w)
245
+ features.append(x)
246
+ return features
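+
+ # Illustrative shape sketch (values are assumptions; `cin` must equal input_channels * downscale_factor**2):
+ # enc = CameraCameraEncoder(downscale_factor=8, cin=192)
+ # cam = torch.randn(1, 3, 4, 256, 384)   # (b, c, f, h, w) per-frame camera embedding
+ # feats = enc(cam)                        # list of (b*f, C_i, H_i, W_i) maps consumed by CameraAdaptor.forward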
genphoto/models/ccl_embedding.py CHANGED
@@ -1,3 +1,64 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:599f8f75460a5b776dc213e624d4e7fc6080c8311d14ffe572501e46512141bf
3
- size 2564
1
+ import torch
2
+
3
+ from transformers import DistilBertTokenizer, DistilBertModel
4
+ from torch.nn.functional import cosine_similarity
5
+
6
+ class FastLightweightTextEncoder:
7
+ def __init__(self, model_name='distilbert-base-uncased', cache_dir='/path/to/your/cache'):
8
+ self.tokenizer = DistilBertTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
9
+ self.text_encoder = DistilBertModel.from_pretrained(model_name, cache_dir=cache_dir).eval().cuda()
10
+
11
+ def encode_texts(self, prompts):
12
+ # Batch processing the prompts to get their embeddings
13
+ inputs = self.tokenizer(prompts, return_tensors="pt", padding=True, truncation=True)
14
+ input_ids = inputs['input_ids'].cuda()
15
+ attention_mask = inputs['attention_mask'].cuda()
16
+
17
+ with torch.no_grad():
18
+ embeddings = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
19
+
20
+ # Normalize embeddings to get consistent vector representations
21
+ embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
22
+
23
+ # Print shape of embeddings
24
+ # print(f"Embeddings shape: {embeddings.shape}")
25
+ return embeddings
26
+
27
+ def calculate_differences(self, embeddings):
28
+ # Calculate differences between consecutive embeddings
29
+ differences = []
30
+ for i in range(1, embeddings.size(0)):
31
+ diff = embeddings[i] - embeddings[i - 1]
32
+ print('diff shape', diff.shape)
33
+ differences.append(diff.unsqueeze(0)) # Add batch dimension
34
+ print('number of differences', len(differences))
35
+
36
+ # Concatenate differences along the batch dimension (f-1)
37
+ concatenated_differences = torch.cat(differences, dim=0) # Shape: (f-1, sequence_length, hidden_size)
38
+ return concatenated_differences
39
+
40
+ # Example usage
41
+ if __name__ == '__main__':
42
+ prompts = [
43
+ "A smiling dog. Focal length: 24mm.",
44
+ "A smiling dog. Focal length: 25mm.",
45
+ "A smiling dog. Focal length: 26mm.",
46
+ "A smiling dog. Focal length: 30mm.",
47
+ "A smiling dog. Focal length: 36mm.",
48
+ ]
49
+
50
+ # Initialize the FastLightweightTextEncoder
51
+ text_encoder = FastLightweightTextEncoder(cache_dir='/home/yuan418/lab/users/Yu/modules/')
52
+
53
+ # Encode the prompts
54
+ embeddings = text_encoder.encode_texts(prompts)
55
+ print('a')
56
+ print('embeddings', embeddings)
57
+ print('embeddings shape', embeddings.shape)
58
+
59
+ # Calculate and concatenate differences
60
+ concatenated_diffs = text_encoder.calculate_differences(embeddings)
61
+
62
+ print("Concatenated differences shape:", concatenated_diffs.shape)
63
+
64
+
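+ # With distilbert-base-uncased (hidden size 768) and these five prompts, `embeddings` has shape
+ # (5, seq_len, 768) and `concatenated_diffs` has shape (4, seq_len, 768), where seq_len is the
+ # padded token length of the batch.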
genphoto/models/motion_module.py CHANGED
@@ -1,3 +1,389 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:5a87e7341d4c8f441adbba3acf43b289589ed0825af8197262425ec35c708d32
3
- size 15717
1
+ from dataclasses import dataclass
2
+ from typing import Callable, Optional
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+
8
+ from diffusers.utils import BaseOutput
9
+ from diffusers.models.attention_processor import Attention
10
+ from diffusers.models.attention import FeedForward
11
+
12
+ from typing import Dict, Any
13
+ from genphoto.models.resnet import InflatedGroupNorm
14
+ from genphoto.models.attention_processor import CameraAdaptorAttnProcessor
15
+
16
+ from einops import rearrange
17
+ import math
18
+
19
+
20
+ def zero_module(module):
21
+ # Zero out the parameters of a module and return it.
22
+ for p in module.parameters():
23
+ p.detach().zero_()
24
+ return module
25
+
26
+
27
+ @dataclass
28
+ class TemporalTransformer3DModelOutput(BaseOutput):
29
+ sample: torch.FloatTensor
30
+
31
+
32
+ def get_motion_module(
33
+ in_channels,
34
+ motion_module_type: str,
35
+ motion_module_kwargs: dict
36
+ ):
37
+ if motion_module_type == "Vanilla":
38
+ return VanillaTemporalModule(in_channels=in_channels, **motion_module_kwargs)
39
+ else:
40
+ raise ValueError
41
+
42
+
43
+ class VanillaTemporalModule(nn.Module):
44
+ def __init__(
45
+ self,
46
+ in_channels,
47
+ num_attention_heads=8,
48
+ num_transformer_block=2,
49
+ attention_block_types=("Temporal_Self",),
50
+ temporal_position_encoding=True,
51
+ temporal_position_encoding_max_len=32,
52
+ temporal_attention_dim_div=1,
53
+ cross_attention_dim=320,
54
+ zero_initialize=True,
55
+ encoder_hidden_states_query=(False, False),
56
+ attention_activation_scale=1.0,
57
+ attention_processor_kwargs: Dict = {},
58
+ causal_temporal_attention=False,
59
+ causal_temporal_attention_mask_type="",
60
+ rescale_output_factor=1.0
61
+ ):
62
+ super().__init__()
63
+
64
+ self.temporal_transformer = TemporalTransformer3DModel(
65
+ in_channels=in_channels,
66
+ num_attention_heads=num_attention_heads,
67
+ attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
68
+ num_layers=num_transformer_block,
69
+ attention_block_types=attention_block_types,
70
+ cross_attention_dim=cross_attention_dim,
71
+ temporal_position_encoding=temporal_position_encoding,
72
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
73
+ encoder_hidden_states_query=encoder_hidden_states_query,
74
+ attention_activation_scale=attention_activation_scale,
75
+ attention_processor_kwargs=attention_processor_kwargs,
76
+ causal_temporal_attention=causal_temporal_attention,
77
+ causal_temporal_attention_mask_type=causal_temporal_attention_mask_type,
78
+ rescale_output_factor=rescale_output_factor
79
+ )
80
+
81
+ if zero_initialize:
82
+ self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
83
+
84
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None,
85
+ cross_attention_kwargs: Dict[str, Any] = {}):
86
+ hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask, cross_attention_kwargs=cross_attention_kwargs)
87
+
88
+ output = hidden_states
89
+ return output
90
+
91
+
92
+ class TemporalTransformer3DModel(nn.Module):
93
+ def __init__(
94
+ self,
95
+ in_channels,
96
+ num_attention_heads,
97
+ attention_head_dim,
98
+ num_layers,
99
+ attention_block_types=("Temporal_Self", "Temporal_Self",),
100
+ dropout=0.0,
101
+ norm_num_groups=32,
102
+ cross_attention_dim=320,
103
+ activation_fn="geglu",
104
+ attention_bias=False,
105
+ upcast_attention=False,
106
+ temporal_position_encoding=False,
107
+ temporal_position_encoding_max_len=32,
108
+ encoder_hidden_states_query=(False, False),
109
+ attention_activation_scale=1.0,
110
+ attention_processor_kwargs: Dict = {},
111
+
112
+ causal_temporal_attention=None,
113
+ causal_temporal_attention_mask_type="",
114
+ rescale_output_factor=1.0
115
+ ):
116
+ super().__init__()
117
+ assert causal_temporal_attention is not None
118
+ self.causal_temporal_attention = causal_temporal_attention
119
+
120
+ assert (not causal_temporal_attention) or (causal_temporal_attention_mask_type != "")
121
+ self.causal_temporal_attention_mask_type = causal_temporal_attention_mask_type
122
+ self.causal_temporal_attention_mask = None
123
+
124
+ inner_dim = num_attention_heads * attention_head_dim
125
+
126
+ self.norm = InflatedGroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
127
+ self.proj_in = nn.Linear(in_channels, inner_dim)
128
+
129
+ self.transformer_blocks = nn.ModuleList(
130
+ [
131
+ TemporalTransformerBlock(
132
+ dim=inner_dim,
133
+ num_attention_heads=num_attention_heads,
134
+ attention_head_dim=attention_head_dim,
135
+ attention_block_types=attention_block_types,
136
+ dropout=dropout,
137
+ norm_num_groups=norm_num_groups,
138
+ cross_attention_dim=cross_attention_dim,
139
+ activation_fn=activation_fn,
140
+ attention_bias=attention_bias,
141
+ upcast_attention=upcast_attention,
142
+ temporal_position_encoding=temporal_position_encoding,
143
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
144
+ encoder_hidden_states_query=encoder_hidden_states_query,
145
+ attention_activation_scale=attention_activation_scale,
146
+ attention_processor_kwargs=attention_processor_kwargs,
147
+ rescale_output_factor=rescale_output_factor,
148
+ )
149
+ for d in range(num_layers)
150
+ ]
151
+ )
152
+ self.proj_out = nn.Linear(inner_dim, in_channels)
153
+
154
+ def get_causal_temporal_attention_mask(self, hidden_states):
155
+ batch_size, sequence_length, dim = hidden_states.shape
156
+
157
+ if self.causal_temporal_attention_mask is None or self.causal_temporal_attention_mask.shape != (
158
+ batch_size, sequence_length, sequence_length):
159
+ if self.causal_temporal_attention_mask_type == "causal":
160
+ # 1. vanilla causal mask
161
+ mask = torch.tril(torch.ones(sequence_length, sequence_length))
162
+
163
+ elif self.causal_temporal_attention_mask_type == "2-seq":
164
+ # 2. 2-seq
165
+ mask = torch.zeros(sequence_length, sequence_length)
166
+ mask[:sequence_length // 2, :sequence_length // 2] = 1
167
+ mask[-sequence_length // 2:, -sequence_length // 2:] = 1
168
+
169
+ elif self.causal_temporal_attention_mask_type == "0-prev":
170
+ # attn to the previous frame
171
+ indices = torch.arange(sequence_length)
172
+ indices_prev = indices - 1
173
+ indices_prev[0] = 0
174
+ mask = torch.zeros(sequence_length, sequence_length)
175
+ mask[:, 0] = 1.
176
+ mask[indices, indices_prev] = 1.
177
+
178
+ elif self.causal_temporal_attention_mask_type == "0":
179
+ # only attn to first frame
180
+ mask = torch.zeros(sequence_length, sequence_length)
181
+ mask[:, 0] = 1
182
+
183
+ elif self.causal_temporal_attention_mask_type == "wo-self":
184
+ indices = torch.arange(sequence_length)
185
+ mask = torch.ones(sequence_length, sequence_length)
186
+ mask[indices, indices] = 0
187
+
188
+ elif self.causal_temporal_attention_mask_type == "circle":
189
+ indices = torch.arange(sequence_length)
190
+ indices_prev = indices - 1
191
+ indices_prev[0] = 0
192
+
193
+ mask = torch.eye(sequence_length)
194
+ mask[indices, indices_prev] = 1
195
+ mask[0, -1] = 1
196
+
197
+ else:
198
+ raise ValueError
199
+
200
+ # generate attention mask fron binary values
201
+ mask = mask.masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
202
+ mask = mask.unsqueeze(0)
203
+ mask = mask.repeat(batch_size, 1, 1)
204
+
205
+ self.causal_temporal_attention_mask = mask.to(hidden_states.device)
206
+
207
+ return self.causal_temporal_attention_mask
208
+
209
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None,
210
+ cross_attention_kwargs: Dict[str, Any] = {},):
211
+ residual = hidden_states
212
+
213
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
214
+ height, width = hidden_states.shape[-2:]
215
+
216
+ hidden_states = self.norm(hidden_states)
217
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b h w) f c")
218
+ hidden_states = self.proj_in(hidden_states)
219
+
220
+ attention_mask = self.get_causal_temporal_attention_mask(
221
+ hidden_states) if self.causal_temporal_attention else attention_mask
222
+
223
+ # Transformer Blocks
224
+ for block in self.transformer_blocks:
225
+ hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states,
226
+ attention_mask=attention_mask, cross_attention_kwargs=cross_attention_kwargs)
227
+ hidden_states = self.proj_out(hidden_states)
228
+
229
+ hidden_states = rearrange(hidden_states, "(b h w) f c -> b c f h w", h=height, w=width)
230
+
231
+ output = hidden_states + residual
232
+
233
+ return output
234
+
235
+
236
+ class TemporalTransformerBlock(nn.Module):
237
+ def __init__(
238
+ self,
239
+ dim,
240
+ num_attention_heads,
241
+ attention_head_dim,
242
+ attention_block_types=("Temporal_Self", "Temporal_Self",),
243
+ dropout=0.0,
244
+ norm_num_groups=32,
245
+ cross_attention_dim=768,
246
+ activation_fn="geglu",
247
+ attention_bias=False,
248
+ upcast_attention=False,
249
+ temporal_position_encoding=False,
250
+ temporal_position_encoding_max_len=32,
251
+ encoder_hidden_states_query=(False, False),
252
+ attention_activation_scale=1.0,
253
+ attention_processor_kwargs: Dict = {},
254
+ rescale_output_factor=1.0
255
+ ):
256
+ super().__init__()
257
+
258
+ attention_blocks = []
259
+ norms = []
260
+ self.attention_block_types = attention_block_types
261
+
262
+ for block_idx, block_name in enumerate(attention_block_types):
263
+ attention_blocks.append(
264
+ TemporalSelfAttention(
265
+ attention_mode=block_name,
266
+ cross_attention_dim=cross_attention_dim if block_name in ['Temporal_Cross', 'Temporal_Camera_Adaptor'] else None,
267
+ query_dim=dim,
268
+ heads=num_attention_heads,
269
+ dim_head=attention_head_dim,
270
+ dropout=dropout,
271
+ bias=attention_bias,
272
+ upcast_attention=upcast_attention,
273
+ temporal_position_encoding=temporal_position_encoding,
274
+ temporal_position_encoding_max_len=temporal_position_encoding_max_len,
275
+ rescale_output_factor=rescale_output_factor,
276
+ )
277
+ )
278
+ norms.append(nn.LayerNorm(dim))
279
+
280
+ self.attention_blocks = nn.ModuleList(attention_blocks)
281
+ self.norms = nn.ModuleList(norms)
282
+
283
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
284
+ self.ff_norm = nn.LayerNorm(dim)
285
+
286
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, cross_attention_kwargs: Dict[str, Any] = {}):
287
+ for attention_block, norm, attention_block_type in zip(self.attention_blocks, self.norms, self.attention_block_types):
288
+ norm_hidden_states = norm(hidden_states)
289
+ hidden_states = attention_block(
290
+ norm_hidden_states,
291
+ encoder_hidden_states=norm_hidden_states if attention_block_type == 'Temporal_Self' else encoder_hidden_states,
292
+ attention_mask=attention_mask,
293
+ **cross_attention_kwargs
294
+ ) + hidden_states
295
+
296
+ hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
297
+
298
+ output = hidden_states
299
+ return output
300
+
301
+
302
+ class PositionalEncoding(nn.Module):
303
+ def __init__(
304
+ self,
305
+ d_model,
306
+ dropout=0.,
307
+ max_len=32,
308
+ ):
309
+ super().__init__()
310
+ self.dropout = nn.Dropout(p=dropout)
311
+ position = torch.arange(max_len).unsqueeze(1)
312
+ div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
313
+ pe = torch.zeros(1, max_len, d_model)
314
+ pe[0, :, 0::2] = torch.sin(position * div_term)
315
+ pe[0, :, 1::2] = torch.cos(position * div_term)
316
+ self.register_buffer('pe', pe)
317
+
318
+ def forward(self, x):
319
+ x = x + self.pe[:, :x.size(1)]
320
+ return self.dropout(x)
321
+
322
+
323
+ class TemporalSelfAttention(Attention):
324
+ def __init__(
325
+ self,
326
+ attention_mode=None,
327
+ temporal_position_encoding=False,
328
+ temporal_position_encoding_max_len=32,
329
+ rescale_output_factor=1.0,
330
+ *args, **kwargs
331
+ ):
332
+ super().__init__(*args, **kwargs)
333
+ assert attention_mode == "Temporal_Self"
334
+
335
+ self.pos_encoder = PositionalEncoding(
336
+ kwargs["query_dim"],
337
+ max_len=temporal_position_encoding_max_len
338
+ ) if temporal_position_encoding else None
339
+ self.rescale_output_factor = rescale_output_factor
340
+
341
+ def set_use_memory_efficient_attention_xformers(
342
+ self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
343
+ ):
344
+ # disable motion module efficient xformers to avoid bad results, don't know why
345
+ # TODO: fix this bug
346
+ pass
347
+
348
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
349
+ # The `Attention` class can call different attention processors / attention functions
350
+ # here we simply pass along all tensors to the selected processor class
351
+ # For standard processors that are defined here, `**cross_attention_kwargs` is empty
352
+
353
+ # add position encoding
354
+ if self.pos_encoder is not None:
355
+ hidden_states = self.pos_encoder(hidden_states)
356
+ if "camera_feature" in cross_attention_kwargs:
357
+ camera_feature = cross_attention_kwargs["camera_feature"]
358
+ if camera_feature.ndim == 5:
359
+ camera_feature = rearrange(camera_feature, "b c f h w -> (b h w) f c")
360
+ else:
361
+ assert camera_feature.ndim == 3
362
+ cross_attention_kwargs["camera_feature"] = camera_feature
363
+
364
+ if isinstance(self.processor, CameraAdaptorAttnProcessor):
365
+ return self.processor(
366
+ self,
367
+ hidden_states,
368
+ cross_attention_kwargs.pop('camera_feature'),
369
+ encoder_hidden_states=None,
370
+ attention_mask=attention_mask,
371
+ **cross_attention_kwargs,
372
+ )
373
+ elif hasattr(self.processor, "__call__"):
374
+ return self.processor.__call__(
375
+ self,
376
+ hidden_states,
377
+ encoder_hidden_states=None,
378
+ attention_mask=attention_mask,
379
+ **cross_attention_kwargs,
380
+ )
381
+ else:
382
+ return self.processor(
383
+ self,
384
+ hidden_states,
385
+ encoder_hidden_states=None,
386
+ attention_mask=attention_mask,
387
+ **cross_attention_kwargs,
388
+ )
389
+
genphoto/models/resnet.py CHANGED
@@ -1,3 +1,440 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:17d68816bfa42b445e7b3c9f6da088e08024a99b838bb1ca74a327e6a9116d50
3
- size 17833
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
2
+
3
+ from einops import rearrange, repeat
4
+ from functools import partial
5
+ from typing import Optional
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+
11
+ from diffusers.models.activations import get_activation
12
+ from diffusers.models.normalization import AdaGroupNorm
13
+ from diffusers.models.attention_processor import SpatialNorm
14
+
15
+
16
+ class InflatedConv3d(nn.Conv2d):
17
+ def forward(self, x):
18
+ video_length = x.shape[2]
19
+
20
+ x = rearrange(x, "b c f h w -> (b f) c h w")
21
+ x = super().forward(x)
22
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
23
+
24
+ return x
25
+
26
+
27
+ class InflatedGroupNorm(nn.GroupNorm):
28
+ def forward(self, x):
29
+ # return super().forward(x)
30
+
31
+ video_length = x.shape[2]
32
+
33
+ x = rearrange(x, "b c f h w -> (b f) c h w")
34
+ x = super().forward(x)
35
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
36
+
37
+ return x
38
+
39
+ def zero_module(module):
40
+ # Zero out the parameters of a module and return it.
41
+ for p in module.parameters():
42
+ p.detach().zero_()
43
+ return module
44
+
45
+
46
+ class FusionBlock2D(nn.Module):
47
+ r"""
48
+ A Resnet block.
49
+
50
+ Parameters:
51
+ in_channels (`int`): The number of channels in the input.
52
+ out_channels (`int`, *optional*, default to be `None`):
53
+ The number of output channels for the first conv2d layer. If None, same as `in_channels`.
54
+ dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
55
+ temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
56
+ groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer.
57
+ groups_out (`int`, *optional*, default to None):
58
+ The number of groups to use for the second normalization layer. if set to None, same as `groups`.
59
+ eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
60
+ non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
61
+ time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
62
+ By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" or
63
+ "ada_group" for a stronger conditioning with scale and shift.
64
+ kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
65
+ [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
66
+ output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
67
+ use_in_shortcut (`bool`, *optional*, default to `True`):
68
+ If `True`, add a 1x1 nn.conv2d layer for skip-connection.
69
+ up (`bool`, *optional*, default to `False`): If `True`, add an upsample layer.
70
+ down (`bool`, *optional*, default to `False`): If `True`, add a downsample layer.
71
+ conv_shortcut_bias (`bool`, *optional*, default to `True`): If `True`, adds a learnable bias to the
72
+ `conv_shortcut` output.
73
+ conv_2d_out_channels (`int`, *optional*, default to `None`): the number of channels in the output.
74
+ If None, same as `out_channels`.
75
+ """
76
+
77
+ def __init__(
78
+ self,
79
+ *,
80
+ in_channels,
81
+ out_channels=None,
82
+ conv_shortcut=False,
83
+ dropout=0.0,
84
+ temb_channels=512,
85
+ groups=32,
86
+ groups_out=None,
87
+ pre_norm=True,
88
+ eps=1e-6,
89
+ non_linearity="swish",
90
+ skip_time_act=False,
91
+ time_embedding_norm="default", # default, scale_shift, ada_group, spatial
92
+ kernel=None,
93
+ output_scale_factor=1.0,
94
+ use_in_shortcut=None,
95
+ up=False,
96
+ down=False,
97
+ conv_shortcut_bias: bool = True,
98
+ conv_2d_out_channels: Optional[int] = None,
99
+
100
+ zero_init=True,
101
+ ):
102
+ super().__init__()
103
+ self.pre_norm = pre_norm
104
+ self.pre_norm = True
105
+
106
+ in_channels = in_channels * 2
107
+ self.in_channels = in_channels
108
+
109
+ out_channels = in_channels * 3 if out_channels is None else out_channels * 3
110
+ self.out_channels = out_channels
111
+
112
+ self.use_conv_shortcut = conv_shortcut
113
+ self.up = up
114
+ self.down = down
115
+ self.output_scale_factor = output_scale_factor
116
+ self.time_embedding_norm = time_embedding_norm
117
+ self.skip_time_act = skip_time_act
118
+
119
+ if groups_out is None:
120
+ groups_out = groups
121
+
122
+ if self.time_embedding_norm == "ada_group":
123
+ self.norm1 = AdaGroupNorm(temb_channels, in_channels, groups, eps=eps)
124
+ elif self.time_embedding_norm == "spatial":
125
+ self.norm1 = SpatialNorm(in_channels, temb_channels)
126
+ else:
127
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
128
+
129
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
130
+
131
+ if temb_channels is not None:
132
+ if self.time_embedding_norm == "default":
133
+ self.time_emb_proj = torch.nn.Linear(temb_channels, out_channels)
134
+ elif self.time_embedding_norm == "scale_shift":
135
+ self.time_emb_proj = torch.nn.Linear(temb_channels, 2 * out_channels)
136
+ elif self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
137
+ self.time_emb_proj = None
138
+ else:
139
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
140
+ else:
141
+ self.time_emb_proj = None
142
+
143
+ if self.time_embedding_norm == "ada_group":
144
+ self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
145
+ elif self.time_embedding_norm == "spatial":
146
+ self.norm2 = SpatialNorm(out_channels, temb_channels)
147
+ else:
148
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
149
+
150
+ self.dropout = torch.nn.Dropout(dropout)
151
+ conv_2d_out_channels = conv_2d_out_channels or out_channels
152
+ self.conv2 = torch.nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0)
153
+
154
+ self.nonlinearity = get_activation(non_linearity)
155
+
156
+ self.upsample = self.downsample = None
157
+ if self.up:
158
+ if kernel == "fir":
159
+ fir_kernel = (1, 3, 3, 1)
160
+ self.upsample = lambda x: upsample_2d(x, kernel=fir_kernel)
161
+ elif kernel == "sde_vp":
162
+ self.upsample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
163
+ else:
164
+ self.upsample = Upsample2D(in_channels, use_conv=False)
165
+ elif self.down:
166
+ if kernel == "fir":
167
+ fir_kernel = (1, 3, 3, 1)
168
+ self.downsample = lambda x: downsample_2d(x, kernel=fir_kernel)
169
+ elif kernel == "sde_vp":
170
+ self.downsample = partial(F.avg_pool2d, kernel_size=2, stride=2)
171
+ else:
172
+ self.downsample = Downsample2D(in_channels, use_conv=False, padding=1, name="op")
173
+
174
+ self.use_in_shortcut = self.in_channels != conv_2d_out_channels if use_in_shortcut is None else use_in_shortcut
175
+
176
+ self.conv_shortcut = None
177
+ if self.use_in_shortcut:
178
+ self.conv_shortcut = torch.nn.Conv2d(
179
+ in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
180
+ )
181
+
182
+ conv_out = torch.nn.Conv2d(
183
+ conv_2d_out_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0,
184
+ )
185
+ self.conv_out = zero_module(conv_out) if zero_init else conv_out
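+ # with conv_out zero-initialised, scale_1/scale_2/shift start at zero and the block initially returns post_hidden_states unchanged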
186
+
187
+ def forward(self, init_hidden_state, post_hidden_states, temb):
188
+ # init_hidden_state: b c 1 h w
189
+ # post_hidden_states: b c (f-1) h w
190
+
191
+ video_length = post_hidden_states.shape[2]
192
+ repeated_init_hidden_state = repeat(init_hidden_state, "b c f h w -> b c (n f) h w", n=video_length)
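+ # tile the first frame's feature along the frame axis so it can be concatenated channel-wise with every later frame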
193
+
194
+ hidden_states = torch.cat([repeated_init_hidden_state, post_hidden_states], dim=1)
195
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
196
+ input_tensor = hidden_states
197
+
198
+ if temb.shape[0] != input_tensor.shape[0]:
199
+ temb = repeat(temb, "b c -> (b n) c", n=input_tensor.shape[0] // temb.shape[0])
200
+ assert temb.shape[0] == input_tensor.shape[0], f"{temb.shape}, {input_tensor.shape}"
201
+
202
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
203
+ hidden_states = self.norm1(hidden_states, temb)
204
+ else:
205
+ hidden_states = self.norm1(hidden_states)
206
+
207
+ hidden_states = self.nonlinearity(hidden_states)
208
+
209
+ if self.upsample is not None:
210
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
211
+ if hidden_states.shape[0] >= 64:
212
+ input_tensor = input_tensor.contiguous()
213
+ hidden_states = hidden_states.contiguous()
214
+ input_tensor = self.upsample(input_tensor)
215
+ hidden_states = self.upsample(hidden_states)
216
+ elif self.downsample is not None:
217
+ input_tensor = self.downsample(input_tensor)
218
+ hidden_states = self.downsample(hidden_states)
219
+
220
+ hidden_states = self.conv1(hidden_states)
221
+
222
+ if self.time_emb_proj is not None:
223
+ if not self.skip_time_act:
224
+ temb = self.nonlinearity(temb)
225
+ temb = self.time_emb_proj(temb)[:, :, None, None]
226
+
227
+ if temb is not None and self.time_embedding_norm == "default":
228
+ hidden_states = hidden_states + temb
229
+
230
+ if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
231
+ hidden_states = self.norm2(hidden_states, temb)
232
+ else:
233
+ hidden_states = self.norm2(hidden_states)
234
+
235
+ if temb is not None and self.time_embedding_norm == "scale_shift":
236
+ scale, shift = torch.chunk(temb, 2, dim=1)
237
+ hidden_states = hidden_states * (1 + scale) + shift
238
+
239
+ hidden_states = self.nonlinearity(hidden_states)
240
+
241
+ hidden_states = self.dropout(hidden_states)
242
+ hidden_states = self.conv2(hidden_states)
243
+
244
+ if self.conv_shortcut is not None:
245
+ input_tensor = self.conv_shortcut(input_tensor)
246
+
247
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
248
+
249
+ output_tensor = self.conv_out(output_tensor)
250
+
251
+ output_tensor = rearrange(output_tensor, "(b f) c h w -> b c f h w", f=video_length)
252
+ scale_1, scale_2, shift = output_tensor.chunk(3, dim=1)
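+ # the 3x-wide output supplies two gates and a shift that modulate the first-frame and per-frame features below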
253
+
254
+ # output_tensor = (1 + scale_1) * repeated_init_hidden_state + scale_2 * post_hidden_states + shift
255
+ output_tensor = scale_1 * repeated_init_hidden_state + (1 + scale_2) * post_hidden_states + shift
256
+
257
+ return output_tensor
258
+
259
+ class Upsample3D(nn.Module):
260
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
261
+ super().__init__()
262
+ self.channels = channels
263
+ self.out_channels = out_channels or channels
264
+ self.use_conv = use_conv
265
+ self.use_conv_transpose = use_conv_transpose
266
+ self.name = name
267
+
268
+ conv = None
269
+ if use_conv_transpose:
270
+ raise NotImplementedError
271
+ elif use_conv:
272
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
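+ # note: forward() calls self.conv unconditionally, so this block is expected to be built with use_conv=True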
273
+
274
+ def forward(self, hidden_states, output_size=None):
275
+ assert hidden_states.shape[1] == self.channels
276
+
277
+ if self.use_conv_transpose:
278
+ raise NotImplementedError
279
+
280
+ # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
281
+ dtype = hidden_states.dtype
282
+ if dtype == torch.bfloat16:
283
+ hidden_states = hidden_states.to(torch.float32)
284
+
285
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
286
+ if hidden_states.shape[0] >= 64:
287
+ hidden_states = hidden_states.contiguous()
288
+
289
+ # if `output_size` is passed we force the interpolation output
290
+ # size and do not make use of `scale_factor=2`
291
+ if output_size is None:
292
+ hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
293
+ else:
294
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
295
+
296
+ # If the input is bfloat16, we cast back to bfloat16
297
+ if dtype == torch.bfloat16:
298
+ hidden_states = hidden_states.to(dtype)
299
+
300
+ # if self.use_conv:
301
+ # if self.name == "conv":
302
+ # hidden_states = self.conv(hidden_states)
303
+ # else:
304
+ # hidden_states = self.Conv2d_0(hidden_states)
305
+ hidden_states = self.conv(hidden_states)
306
+
307
+ return hidden_states
308
+
309
+
310
+ class Downsample3D(nn.Module):
311
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
312
+ super().__init__()
313
+ self.channels = channels
314
+ self.out_channels = out_channels or channels
315
+ self.use_conv = use_conv
316
+ self.padding = padding
317
+ stride = 2
318
+ self.name = name
319
+
320
+ if use_conv:
321
+ self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
322
+ else:
323
+ raise NotImplementedError
324
+
325
+ def forward(self, hidden_states):
326
+ assert hidden_states.shape[1] == self.channels
327
+ if self.use_conv and self.padding == 0:
328
+ raise NotImplementedError
329
+
330
+ assert hidden_states.shape[1] == self.channels
331
+ hidden_states = self.conv(hidden_states)
332
+
333
+ return hidden_states
334
+
335
+
336
+ class ResnetBlock3D(nn.Module):
337
+ def __init__(
338
+ self,
339
+ *,
340
+ in_channels,
341
+ out_channels=None,
342
+ conv_shortcut=False,
343
+ dropout=0.0,
344
+ temb_channels=512,
345
+ groups=32,
346
+ groups_out=None,
347
+ pre_norm=True,
348
+ eps=1e-6,
349
+ non_linearity="swish",
350
+ time_embedding_norm="default",
351
+ output_scale_factor=1.0,
352
+ use_in_shortcut=None,
353
+ ):
354
+ super().__init__()
355
+ self.pre_norm = pre_norm
356
+ self.pre_norm = True
357
+ self.in_channels = in_channels
358
+ out_channels = in_channels if out_channels is None else out_channels
359
+ self.out_channels = out_channels
360
+ self.use_conv_shortcut = conv_shortcut
361
+ self.time_embedding_norm = time_embedding_norm
362
+ self.output_scale_factor = output_scale_factor
363
+
364
+ if groups_out is None:
365
+ groups_out = groups
366
+
367
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
368
+
369
+ self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
370
+
371
+ if temb_channels is not None:
372
+ if self.time_embedding_norm == "default":
373
+ time_emb_proj_out_channels = out_channels
374
+ elif self.time_embedding_norm == "scale_shift":
375
+ time_emb_proj_out_channels = out_channels * 2
376
+ else:
377
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
378
+
379
+ self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
380
+ else:
381
+ self.time_emb_proj = None
382
+
383
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
384
+ self.dropout = torch.nn.Dropout(dropout)
385
+ self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
386
+
387
+ if non_linearity == "swish":
388
+ self.nonlinearity = lambda x: F.silu(x)
389
+ elif non_linearity == "mish":
390
+ self.nonlinearity = Mish()
391
+ elif non_linearity == "silu":
392
+ self.nonlinearity = nn.SiLU()
393
+
394
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
395
+
396
+ self.conv_shortcut = None
397
+ if self.use_in_shortcut:
398
+ self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
399
+
400
+ def forward(self, input_tensor, temb):
401
+ # input: b c f h w
402
+
403
+ hidden_states = input_tensor
404
+
405
+ video_length = hidden_states.shape[2]
406
+ # temb keeps shape (b, c); after the projection below it broadcasts over the frame axis, so no per-frame repeat is needed here
407
+
408
+ hidden_states = self.norm1(hidden_states)
409
+ hidden_states = self.nonlinearity(hidden_states)
410
+
411
+ hidden_states = self.conv1(hidden_states)
412
+
413
+ if temb is not None:
414
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
415
+
416
+ if temb is not None and self.time_embedding_norm == "default":
417
+ hidden_states = hidden_states + temb
418
+
419
+ hidden_states = self.norm2(hidden_states)
420
+
421
+ if temb is not None and self.time_embedding_norm == "scale_shift":
422
+ scale, shift = torch.chunk(temb, 2, dim=1)
423
+ hidden_states = hidden_states * (1 + scale) + shift
424
+
425
+ hidden_states = self.nonlinearity(hidden_states)
426
+
427
+ hidden_states = self.dropout(hidden_states)
428
+ hidden_states = self.conv2(hidden_states)
429
+
430
+ if self.conv_shortcut is not None:
431
+ input_tensor = self.conv_shortcut(input_tensor)
432
+
433
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
434
+
435
+ return output_tensor
436
+
437
+
438
+ class Mish(torch.nn.Module):
439
+ def forward(self, hidden_states):
440
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
genphoto/models/unet.py CHANGED
@@ -1,3 +1,1300 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:66cd2f1e572a9d63f9ff6e1dc5bbacadd02916fc60cef9505761b6470c51f08e
3
- size 61839
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
2
+ import os
3
+ import json
4
+ import safetensors
5
+ import logging
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.utils.checkpoint
9
+
10
+ from einops import repeat, rearrange
11
+ from dataclasses import dataclass
12
+ from typing import List, Optional, Tuple, Union, Dict, Any
13
+
14
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
15
+ from diffusers.models.attention_processor import AttentionProcessor
16
+
17
+ from diffusers.models.modeling_utils import ModelMixin
18
+ from diffusers.utils import BaseOutput, logging
19
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
20
+ from diffusers.models.attention_processor import LoRAAttnProcessor
21
+ from diffusers.loaders import AttnProcsLayers, UNet2DConditionLoadersMixin
22
+
23
+ from genphoto.models.unet_blocks import (
24
+ CrossAttnDownBlock3D,
25
+ CrossAttnUpBlock3D,
26
+ DownBlock3D,
27
+ UNetMidBlock3DCrossAttn,
28
+ UpBlock3D,
29
+ get_down_block,
30
+ get_up_block,
31
+ )
32
+ from genphoto.models.attention_processor import (
33
+ LORACameraAdaptorAttnProcessor,
34
+ CameraAdaptorAttnProcessor
35
+ )
36
+ from genphoto.models.attention_processor import LoRAAttnProcessor as CustomizedLoRAAttnProcessor
37
+ from genphoto.models.attention_processor import AttnProcessor as CustomizedAttnProcessor
38
+ from genphoto.models.resnet import (
39
+ InflatedConv3d,
40
+ FusionBlock2D
41
+ )
42
+
43
+ @dataclass
44
+ class UNet3DConditionOutput(BaseOutput):
45
+ sample: torch.FloatTensor
46
+
47
+
48
+ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
49
+ _supports_gradient_checkpointing = True
50
+
51
+ @register_to_config
52
+ def __init__(
53
+ self,
54
+ sample_size: Optional[int] = None,
55
+ in_channels: int = 4,
56
+ out_channels: int = 4,
57
+ center_input_sample: bool = False,
58
+ flip_sin_to_cos: bool = True,
59
+ freq_shift: int = 0,
60
+ down_block_types: Tuple[str] = (
61
+ "CrossAttnDownBlock3D",
62
+ "CrossAttnDownBlock3D",
63
+ "CrossAttnDownBlock3D",
64
+ "DownBlock3D",
65
+ ),
66
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
67
+ up_block_types: Tuple[str] = (
68
+ "UpBlock3D",
69
+ "CrossAttnUpBlock3D",
70
+ "CrossAttnUpBlock3D",
71
+ "CrossAttnUpBlock3D",
72
+ ),
73
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
74
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
75
+ layers_per_block: int = 2,
76
+ downsample_padding: int = 1,
77
+ mid_block_scale_factor: float = 1,
78
+ act_fn: str = "silu",
79
+ norm_num_groups: int = 32,
80
+ norm_eps: float = 1e-5,
81
+ cross_attention_dim: int = 1280,
82
+ attention_head_dim: Union[int, Tuple[int]] = 8,
83
+ dual_cross_attention: bool = False,
84
+ use_linear_projection: bool = False,
85
+ class_embed_type: Optional[str] = None,
86
+ addition_embed_type: Optional[str] = None,
87
+ num_class_embeds: Optional[int] = None,
88
+ upcast_attention: bool = False,
89
+ resnet_time_scale_shift: str = "default",
90
+
91
+ # Additional
92
+ use_motion_module=False,
93
+ motion_module_resolutions=(1, 2, 4, 8),
94
+ motion_module_mid_block=False,
95
+ motion_module_type=None,
96
+ motion_module_kwargs={},
97
+
98
+ # whether fuse first frame's feature
99
+ fuse_first_frame: bool = False,
100
+ ):
101
+ super().__init__()
102
+ self.logger = logging.get_logger(__name__)
103
+
104
+ self.sample_size = sample_size
105
+ time_embed_dim = block_out_channels[0] * 4
106
+
107
+ # input
108
+ self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
109
+
110
+ # time
111
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
112
+ timestep_input_dim = block_out_channels[0]
113
+
114
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
115
+
116
+ # class embedding
117
+ if class_embed_type is None and num_class_embeds is not None:
118
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
119
+ elif class_embed_type == "timestep":
120
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
121
+ elif class_embed_type == "identity":
122
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
123
+ else:
124
+ self.class_embedding = None
125
+
126
+ self.down_blocks = nn.ModuleList([])
127
+ self.mid_block = None
128
+ self.up_blocks = nn.ModuleList([])
129
+
130
+ self.down_fusers = nn.ModuleList([])
131
+ self.mid_fuser = None
132
+ self.down_fusers.append(
133
+ FusionBlock2D(
134
+ in_channels=block_out_channels[0],
135
+ out_channels=block_out_channels[0],
136
+ temb_channels=time_embed_dim,
137
+ eps=norm_eps,
138
+ groups=norm_num_groups,
139
+ time_embedding_norm=resnet_time_scale_shift,
140
+ non_linearity=act_fn,
141
+ ) if fuse_first_frame else None
142
+ )
143
+
144
+ if isinstance(only_cross_attention, bool):
145
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
146
+
147
+ if isinstance(attention_head_dim, int):
148
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
149
+
150
+ # down
151
+ output_channel = block_out_channels[0]
152
+ for i, down_block_type in enumerate(down_block_types):
153
+ res = 2 ** i
154
+ input_channel = output_channel
155
+ output_channel = block_out_channels[i]
156
+ is_final_block = i == len(block_out_channels) - 1
157
+
158
+ down_block = get_down_block(
159
+ down_block_type,
160
+ num_layers=layers_per_block,
161
+ in_channels=input_channel,
162
+ out_channels=output_channel,
163
+ temb_channels=time_embed_dim,
164
+ add_downsample=not is_final_block,
165
+ resnet_eps=norm_eps,
166
+ resnet_act_fn=act_fn,
167
+ resnet_groups=norm_num_groups,
168
+ cross_attention_dim=cross_attention_dim,
169
+ attn_num_head_channels=attention_head_dim[i],
170
+ downsample_padding=downsample_padding,
171
+ dual_cross_attention=dual_cross_attention,
172
+ use_linear_projection=use_linear_projection,
173
+ only_cross_attention=only_cross_attention[i],
174
+ upcast_attention=upcast_attention,
175
+ resnet_time_scale_shift=resnet_time_scale_shift,
176
+
177
+ use_motion_module=use_motion_module and (res in motion_module_resolutions),
178
+ motion_module_type=motion_module_type,
179
+ motion_module_kwargs=motion_module_kwargs,
180
+ )
181
+
182
+ down_fuser = nn.ModuleList(
183
+ [
184
+ FusionBlock2D(
185
+ in_channels=output_channel,
186
+ out_channels=output_channel,
187
+ temb_channels=time_embed_dim,
188
+ eps=norm_eps,
189
+ groups=norm_num_groups,
190
+ time_embedding_norm=resnet_time_scale_shift,
191
+ non_linearity=act_fn,
192
+ ) if fuse_first_frame else None for _ in
193
+ range(layers_per_block if is_final_block else layers_per_block + 1)
194
+ ]
195
+ )
196
+
197
+ self.down_blocks.append(down_block)
198
+ self.down_fusers.append(down_fuser)
199
+
200
+ # mid
201
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
202
+ self.mid_block = UNetMidBlock3DCrossAttn(
203
+ in_channels=block_out_channels[-1],
204
+ temb_channels=time_embed_dim,
205
+ resnet_eps=norm_eps,
206
+ resnet_act_fn=act_fn,
207
+ output_scale_factor=mid_block_scale_factor,
208
+ resnet_time_scale_shift=resnet_time_scale_shift,
209
+ cross_attention_dim=cross_attention_dim,
210
+ attn_num_head_channels=attention_head_dim[-1],
211
+ resnet_groups=norm_num_groups,
212
+ dual_cross_attention=dual_cross_attention,
213
+ use_linear_projection=use_linear_projection,
214
+ upcast_attention=upcast_attention,
215
+
216
+ use_motion_module=use_motion_module and motion_module_mid_block,
217
+ motion_module_type=motion_module_type,
218
+ motion_module_kwargs=motion_module_kwargs,
219
+ )
220
+ else:
221
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
222
+
223
+ self.mid_fuser = FusionBlock2D(
224
+ in_channels=block_out_channels[-1],
225
+ out_channels=block_out_channels[-1],
226
+ temb_channels=time_embed_dim,
227
+ eps=norm_eps,
228
+ groups=norm_num_groups,
229
+ time_embedding_norm=resnet_time_scale_shift,
230
+ non_linearity=act_fn,
231
+ ) if fuse_first_frame else None
232
+
233
+ # count how many layers upsample the videos
234
+ self.num_upsamplers = 0
235
+
236
+ # up
237
+ reversed_block_out_channels = list(reversed(block_out_channels))
238
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
239
+ only_cross_attention = list(reversed(only_cross_attention))
240
+ output_channel = reversed_block_out_channels[0]
241
+ for i, up_block_type in enumerate(up_block_types):
242
+ res = 2 ** (3 - i)
243
+ is_final_block = i == len(block_out_channels) - 1
244
+
245
+ prev_output_channel = output_channel
246
+ output_channel = reversed_block_out_channels[i]
247
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
248
+
249
+ # add upsample block for all BUT final layer
250
+ if not is_final_block:
251
+ add_upsample = True
252
+ self.num_upsamplers += 1
253
+ else:
254
+ add_upsample = False
255
+
256
+ up_block = get_up_block(
257
+ up_block_type,
258
+ num_layers=layers_per_block + 1,
259
+ in_channels=input_channel,
260
+ out_channels=output_channel,
261
+ prev_output_channel=prev_output_channel,
262
+ temb_channels=time_embed_dim,
263
+ add_upsample=add_upsample,
264
+ resnet_eps=norm_eps,
265
+ resnet_act_fn=act_fn,
266
+ resnet_groups=norm_num_groups,
267
+ cross_attention_dim=cross_attention_dim,
268
+ attn_num_head_channels=reversed_attention_head_dim[i],
269
+ dual_cross_attention=dual_cross_attention,
270
+ use_linear_projection=use_linear_projection,
271
+ only_cross_attention=only_cross_attention[i],
272
+ upcast_attention=upcast_attention,
273
+ resnet_time_scale_shift=resnet_time_scale_shift,
274
+
275
+ use_motion_module=use_motion_module and (res in motion_module_resolutions),
276
+ motion_module_type=motion_module_type,
277
+ motion_module_kwargs=motion_module_kwargs,
278
+ )
279
+ self.up_blocks.append(up_block)
280
+ prev_output_channel = output_channel
281
+
282
+ # out
283
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
284
+ self.conv_act = nn.SiLU()
285
+ self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
286
+
287
+ def set_image_layer_lora(self, image_layer_lora_rank: int = 128):
288
+ lora_attn_procs = {}
289
+ for name in self.attn_processors.keys():
290
+ self.logger.info(f"(add lora) {name}")
291
+ cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim
292
+ if name.startswith("mid_block"):
293
+ hidden_size = self.config.block_out_channels[-1]
294
+ elif name.startswith("up_blocks"):
295
+ block_id = int(name[len("up_blocks.")])
296
+ hidden_size = list(reversed(self.config.block_out_channels))[block_id]
297
+ elif name.startswith("down_blocks"):
298
+ block_id = int(name[len("down_blocks.")])
299
+ hidden_size = self.config.block_out_channels[block_id]
300
+
301
+ lora_attn_procs[name] = LoRAAttnProcessor(
302
+ hidden_size=hidden_size,
303
+ cross_attention_dim=cross_attention_dim,
304
+ rank=image_layer_lora_rank if image_layer_lora_rank > 16 else hidden_size // image_layer_lora_rank,
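+ # values of 16 or less are treated as a divisor of hidden_size rather than an absolute LoRA rank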
305
+ )
306
+ self.set_attn_processor(lora_attn_procs)
307
+
308
+ lora_layers = AttnProcsLayers(self.attn_processors)
309
+ self.logger.info(f"(lora parameters): {sum(p.numel() for p in lora_layers.parameters()) / 1e6:.3f} M")
310
+ del lora_layers
311
+
312
+ def set_image_layer_lora_scale(self, lora_scale: float = 1.0):
313
+ for block in self.down_blocks: setattr(block, "lora_scale", lora_scale)
314
+ for block in self.up_blocks: setattr(block, "lora_scale", lora_scale)
315
+ setattr(self.mid_block, "lora_scale", lora_scale)
316
+
317
+ def set_motion_module_lora_scale(self, lora_scale: float = 1.0):
318
+ for block in self.down_blocks: setattr(block, "motion_lora_scale", lora_scale)
319
+ for block in self.up_blocks: setattr(block, "motion_lora_scale", lora_scale)
320
+ setattr(self.mid_block, "motion_lora_scale", lora_scale)
321
+
322
+ @property
323
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
324
+ r"""
325
+ Returns:
326
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
327
+ indexed by their weight names.
328
+ """
329
+ # set recursively
330
+ processors = {}
331
+
332
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
333
+ # filter out processors in motion module
334
+ if hasattr(module, "set_processor"):
335
+ if not "motion_modules." in name:
336
+ processors[f"{name}.processor"] = module.processor
337
+
338
+ for sub_name, child in module.named_children():
339
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
340
+
341
+ return processors
342
+
343
+ for name, module in self.named_children():
344
+ fn_recursive_add_processors(name, module, processors)
345
+
346
+ return processors
347
+
348
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
349
+ r"""
350
+ Sets the attention processor to use to compute attention.
351
+
352
+ Parameters:
353
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
354
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
355
+ for **all** `Attention` layers.
356
+
357
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
358
+ processor. This is strongly recommended when setting trainable attention processors.
359
+
360
+ """
361
+ count = len(self.attn_processors.keys())
362
+
363
+ if isinstance(processor, dict) and len(processor) != count:
364
+ raise ValueError(
365
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
366
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
367
+ )
368
+
369
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
370
+ if hasattr(module, "set_processor"):
371
+ if not "motion_modules." in name:
372
+ if not isinstance(processor, dict):
373
+ module.set_processor(processor)
374
+ else:
375
+ module.set_processor(processor.pop(f"{name}.processor"))
376
+
377
+ for sub_name, child in module.named_children():
378
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
379
+
380
+ for name, module in self.named_children():
381
+ fn_recursive_attn_processor(name, module, processor)
382
+
383
+ def set_motion_module_lora_layers(self, motion_module_lora_rank: int = 32):
384
+ lora_attn_procs = {}
385
+ for name in self.mm_attn_processors.keys():
386
+ self.logger.info(f"(add lora) {name}")
387
+ cross_attention_dim = None
388
+ if name.startswith("mid_block"):
389
+ hidden_size = self.config.block_out_channels[-1]
390
+ elif name.startswith("up_blocks"):
391
+ block_id = int(name[len("up_blocks.")])
392
+ hidden_size = list(reversed(self.config.block_out_channels))[block_id]
393
+ elif name.startswith("down_blocks"):
394
+ block_id = int(name[len("down_blocks.")])
395
+ hidden_size = self.config.block_out_channels[block_id]
396
+
397
+ lora_attn_procs[name] = LoRAAttnProcessor(
398
+ hidden_size=hidden_size,
399
+ cross_attention_dim=cross_attention_dim,
400
+ rank=motion_module_lora_rank if motion_module_lora_rank > 16 else hidden_size // motion_module_lora_rank,
401
+ )
402
+ self.set_mm_attn_processor(lora_attn_procs)
403
+
404
+ lora_layers = AttnProcsLayers(self.mm_attn_processors)
405
+ return lora_layers
406
+
407
+ @property
408
+ def mm_attn_processors(self) -> Dict[str, AttentionProcessor]:
409
+ r"""
410
+ Returns:
411
+ `dict` of attention processors: A dictionary containing all attention processors used in the model,
412
+ indexed by their weight names.
413
+ """
414
+ # set recursively
415
+ processors = {}
416
+
417
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module,
418
+ processors: Dict[str, AttentionProcessor]):
419
+ # filter out processors in motion module
420
+ if hasattr(module, "set_processor"):
421
+ if "motion_modules." in name:
422
+ processors[f"{name}.processor"] = module.processor
423
+
424
+ for sub_name, child in module.named_children():
425
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
426
+
427
+ return processors
428
+
429
+ for name, module in self.named_children():
430
+ fn_recursive_add_processors(name, module, processors)
431
+
432
+ return processors
433
+
434
+ def set_mm_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
435
+ r"""
436
+ Sets the attention processor to use to compute attention.
437
+
438
+ Parameters:
439
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
440
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
441
+ for **all** `Attention` layers.
442
+
443
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
444
+ processor. This is strongly recommended when setting trainable attention processors.
445
+
446
+ """
447
+ count = len(self.mm_attn_processors.keys())
448
+
449
+ if isinstance(processor, dict) and len(processor) != count:
450
+ raise ValueError(
451
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
452
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
453
+ )
454
+
455
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
456
+ if hasattr(module, "set_processor"):
457
+ if "motion_modules." in name:
458
+ if not isinstance(processor, dict):
459
+ module.set_processor(processor)
460
+ else:
461
+ module.set_processor(processor.pop(f"{name}.processor"))
462
+
463
+ for sub_name, child in module.named_children():
464
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
465
+
466
+ for name, module in self.named_children():
467
+ fn_recursive_attn_processor(name, module, processor)
468
+
469
+ def set_attention_slice(self, slice_size):
470
+ r"""
471
+ Enable sliced attention computation.
472
+
473
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
474
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
475
+
476
+ Args:
477
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
478
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
479
+ `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
480
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
481
+ must be a multiple of `slice_size`.
482
+ """
483
+ sliceable_head_dims = []
484
+
485
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
486
+ if hasattr(module, "set_attention_slice"):
487
+ sliceable_head_dims.append(module.sliceable_head_dim)
488
+
489
+ for child in module.children():
490
+ fn_recursive_retrieve_slicable_dims(child)
491
+
492
+ # retrieve number of attention layers
493
+ for module in self.children():
494
+ fn_recursive_retrieve_slicable_dims(module)
495
+
496
+ num_slicable_layers = len(sliceable_head_dims)
497
+
498
+ if slice_size == "auto":
499
+ # half the attention head size is usually a good trade-off between
500
+ # speed and memory
501
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
502
+ elif slice_size == "max":
503
+ # make smallest slice possible
504
+ slice_size = num_slicable_layers * [1]
505
+
506
+ slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
507
+
508
+ if len(slice_size) != len(sliceable_head_dims):
509
+ raise ValueError(
510
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
511
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
512
+ )
513
+
514
+ for i in range(len(slice_size)):
515
+ size = slice_size[i]
516
+ dim = sliceable_head_dims[i]
517
+ if size is not None and size > dim:
518
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
519
+
520
+ # Recursively walk through all the children.
521
+ # Any children which exposes the set_attention_slice method
522
+ # gets the message
523
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
524
+ if hasattr(module, "set_attention_slice"):
525
+ module.set_attention_slice(slice_size.pop())
526
+
527
+ for child in module.children():
528
+ fn_recursive_set_attention_slice(child, slice_size)
529
+
530
+ reversed_slice_size = list(reversed(slice_size))
531
+ for module in self.children():
532
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
533
+
534
+ def _set_gradient_checkpointing(self, module, value=False):
535
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
536
+ module.gradient_checkpointing = value
537
+
538
+ def forward(
539
+ self,
540
+ sample: torch.FloatTensor,
541
+ timestep: Union[torch.Tensor, float, int],
542
+ encoder_hidden_states: Union[torch.Tensor, List[torch.Tensor]],
543
+ class_labels: Optional[torch.Tensor] = None,
544
+ attention_mask: Optional[torch.Tensor] = None,
545
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
546
+ return_dict: bool = True,
547
+
548
+ # support controlnet
549
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
550
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
551
+
552
+ # other features
553
+ motion_module_alphas: Union[tuple, float] = 1.0,
554
+ debug: bool = False,
555
+ ) -> Union[UNet3DConditionOutput, Tuple]:
556
+
557
+ activations = {}
558
+
559
+ r"""
560
+ Args:
561
+ sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor
562
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
563
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
564
+ return_dict (`bool`, *optional*, defaults to `True`):
565
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
566
+
567
+ Returns:
568
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
569
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
570
+ returning a tuple, the first element is the sample tensor.
571
+ """
572
+ # By default samples have to be at least a multiple of the overall upsampling factor.
574
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
574
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
575
+ # on the fly if necessary.
576
+ default_overall_up_factor = 2 ** self.num_upsamplers
577
+
578
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
579
+ forward_upsample_size = False
580
+ upsample_size = None
581
+
582
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
583
+ self.logger.info("Forward upsample size to force interpolation output size.")
584
+ forward_upsample_size = True
585
+
586
+ # prepare attention_mask
587
+ if attention_mask is not None:
588
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
589
+ attention_mask = attention_mask.unsqueeze(1)
590
+
591
+ # center input if necessary
592
+ if self.config.center_input_sample:
593
+ sample = 2 * sample - 1.0
594
+
595
+ # time
596
+ timesteps = timestep
597
+ if not torch.is_tensor(timesteps):
598
+ # This would be a good case for the `match` statement (Python 3.10+)
599
+ is_mps = sample.device.type == "mps"
600
+ if isinstance(timestep, float):
601
+ dtype = torch.float32 if is_mps else torch.float64
602
+ else:
603
+ dtype = torch.int32 if is_mps else torch.int64
604
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
605
+ elif len(timesteps.shape) == 0:
606
+ timesteps = timesteps[None].to(sample.device)
607
+
608
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
609
+ timesteps = timesteps.expand(sample.shape[0])
610
+
611
+ t_emb = self.time_proj(timesteps)
612
+
613
+ # timesteps does not contain any weights and will always return f32 tensors
614
+ # but time_embedding might actually be running in fp16. so we need to cast here.
615
+ # there might be better ways to encapsulate this.
616
+ t_emb = t_emb.to(dtype=self.dtype)
617
+ emb = self.time_embedding(t_emb)
618
+
619
+ if self.class_embedding is not None:
620
+ if class_labels is None:
621
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
622
+
623
+ if self.config.class_embed_type == "timestep":
624
+ class_labels = self.time_proj(class_labels)
625
+
626
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
627
+ emb = emb + class_emb
628
+
629
+ # extend encoder_hidden_states
630
+ video_length = sample.shape[2]
631
+ encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b f) n c", f=video_length)
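+ # give every frame its own copy of the text embedding to match the (b f) batching used inside the blocks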
632
+
633
+ emb_single = emb  # unrepeated (b, c) time embedding, consumed by the fusion blocks below
634
+ # emb = repeat(emb, "b c -> (b f) c", f=video_length)
635
+
636
+ # pre-process
637
+ sample = self.conv_in(sample)
638
+ activations["conv_in_out"] = sample
639
+
640
+ # to be fused
641
+ if self.down_fusers[0] != None:
642
+ # scale, shift = self.down_fusers[0](sample[:,:,0].contiguous(), emb_single).unsqueeze(2).chunk(2, dim=1)
643
+ # sample[:,:,1:] = (1 + scale) * sample[:,:,1:].contiguous() + shift
644
+ fused_sample = self.down_fusers[0](
645
+ init_hidden_state=sample[:, :, :1].contiguous(),
646
+ post_hidden_states=sample[:, :, 1:].contiguous(),
647
+ temb=emb_single,
648
+ )
649
+ sample = torch.cat([sample[:, :, :1], fused_sample], dim=2)
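+ # keep the first frame as-is and replace the later frames with their fused counterparts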
650
+
651
+ activations["conv_in_fuse_out"] = sample
652
+
653
+ # down
654
+ down_block_res_samples = (sample,)
655
+
656
+ # motion module alpha
657
+ if isinstance(motion_module_alphas, float):
658
+ motion_module_alphas = (motion_module_alphas,) * 5
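+ # one alpha per down-block resolution (reused in reverse order for the up blocks) plus one for the mid block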
659
+
660
+ for downsample_block, down_fuser, motion_module_alpha in zip(self.down_blocks, self.down_fusers[1:],
661
+ motion_module_alphas[:-1]):
662
+ sample, res_samples = downsample_block(
663
+ hidden_states=sample,
664
+ temb=emb,
665
+ encoder_hidden_states=encoder_hidden_states,
666
+ attention_mask=attention_mask,
667
+ motion_module_alpha=motion_module_alpha,
668
+ cross_attention_kwargs=cross_attention_kwargs
669
+ )
670
+ # to be fused
671
+ for sample_idx, fuser in enumerate(down_fuser):
672
+ if fuser != None:
673
+ fused_sample = fuser(
674
+ init_hidden_state=res_samples[sample_idx][:, :, :1].contiguous(),
675
+ post_hidden_states=res_samples[sample_idx][:, :, 1:].contiguous(),
676
+ temb=emb_single,
677
+ )
678
+ res_samples = list(res_samples)
679
+ res_samples[sample_idx] = torch.cat([res_samples[sample_idx][:, :, :1], fused_sample], dim=2)
680
+ res_samples = tuple(res_samples)
681
+
682
+ down_block_res_samples += res_samples
683
+
684
+ # support controlnet
685
+ if down_block_additional_residuals is not None:
686
+ new_down_block_res_samples = ()
687
+
688
+ for down_block_res_sample, down_block_additional_residual in zip(
689
+ down_block_res_samples, down_block_additional_residuals
690
+ ):
691
+ if len(down_block_additional_residual.shape) == 4:
692
+ # b c h w
693
+ # if input single condition, apply it to all frames
694
+ down_block_additional_residual = down_block_additional_residual.unsqueeze(2)
695
+ # broadcast will solve the problem
696
+ # down_block_additional_residual = repeat(down_block_additional_residual, "b c f h w -> b c (f n) h w", n=video_length)
697
+
698
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
699
+ new_down_block_res_samples += (down_block_res_sample,)
700
+
701
+ down_block_res_samples = new_down_block_res_samples
702
+
703
+ # mid
704
+ sample = self.mid_block(
705
+ sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask,
706
+ motion_module_alpha=motion_module_alphas[-1], cross_attention_kwargs=cross_attention_kwargs
707
+ )
708
+
709
+ # mid block fuser
710
+ if self.mid_fuser != None:
711
+ fused_sample = self.mid_fuser(
712
+ init_hidden_state=sample[:, :, :1],
713
+ post_hidden_states=sample[:, :, 1:],
714
+ temb=emb_single,
715
+ )
716
+ sample = torch.cat([sample[:, :, :1], fused_sample], dim=2)
717
+
718
+ # support controlnet
719
+ if mid_block_additional_residual is not None:
720
+ if len(mid_block_additional_residual.shape) == 4:
721
+ mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2)
722
+ # broadcast will solve this problem
723
+ # mid_block_additional_residual = repeat(mid_block_additional_residual, "b c f h w -> b c (f n) h w", n=video_length)
724
+
725
+ sample = sample + mid_block_additional_residual
726
+
727
+ # up
728
+ for i, (upsample_block, motion_module_alpha) in enumerate(zip(self.up_blocks, motion_module_alphas[:-1][::-1])):
729
+ is_final_block = i == len(self.up_blocks) - 1
730
+
731
+ res_samples = down_block_res_samples[-len(upsample_block.resnets):]
732
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
733
+
734
+ # if we have not reached the final block and need to forward the
735
+ # upsample size, we do it here
736
+ if not is_final_block and forward_upsample_size:
737
+ upsample_size = down_block_res_samples[-1].shape[2:]
738
+
739
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
740
+ sample = upsample_block(
741
+ hidden_states=sample,
742
+ temb=emb,
743
+ res_hidden_states_tuple=res_samples,
744
+ encoder_hidden_states=encoder_hidden_states,
745
+ upsample_size=upsample_size,
746
+ attention_mask=attention_mask,
747
+ motion_module_alpha=motion_module_alpha,
748
+ cross_attention_kwargs=cross_attention_kwargs
749
+ )
750
+ else:
751
+ sample = upsample_block(
752
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size,
753
+ encoder_hidden_states=encoder_hidden_states, motion_module_alpha=motion_module_alpha,
754
+ cross_attention_kwargs=cross_attention_kwargs
755
+ )
756
+ activations["upblocks_out"] = sample
757
+
758
+ # post-process
759
+ # frame-wise normalization
760
+ sample = rearrange(sample, "b c f h w -> (b f) c h w")
761
+ sample = self.conv_norm_out(sample)
762
+ sample = rearrange(sample, "(b f) c h w -> b c f h w", f=video_length)
763
+
764
+ sample = self.conv_act(sample)
765
+ sample = self.conv_out(sample)
766
+
767
+ if (not return_dict):
768
+ return (sample,)
769
+ elif debug:
770
+ return UNet3DConditionOutput(sample=sample), activations
771
+ else:
772
+ return UNet3DConditionOutput(sample=sample)
773
+
774
+ @classmethod
775
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, unet_additional_kwargs=None, logger=None):
776
+ if logger is not None:
777
+ logger.info(f"Loading unet's pretrained weights from {pretrained_model_path} ...")
778
+
779
+ if subfolder is not None:
780
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
781
+
782
+ config_file = os.path.join(pretrained_model_path, 'config.json')
783
+ if not os.path.isfile(config_file):
784
+ raise RuntimeError(f"{config_file} does not exist")
785
+
786
+ with open(config_file, "r") as f:
787
+ config = json.load(f)
788
+
789
+ config["_class_name"] = cls.__name__
790
+ config["down_block_types"] = [
791
+ "CrossAttnDownBlock3D",
792
+ "CrossAttnDownBlock3D",
793
+ "CrossAttnDownBlock3D",
794
+ "DownBlock3D"
795
+ ]
796
+ config["up_block_types"] = [
797
+ "UpBlock3D",
798
+ "CrossAttnUpBlock3D",
799
+ "CrossAttnUpBlock3D",
800
+ "CrossAttnUpBlock3D"
801
+ ]
802
+
803
+ from diffusers.utils import SAFETENSORS_WEIGHTS_NAME
804
+
805
+ model, unused_kwargs = cls.from_config(config, return_unused_kwargs=True, **unet_additional_kwargs)
806
+ if logger is not None:
807
+ logger.info(f"please check unused kwargs in 'unet_additional_kwargs' config:")
808
+ for k, v in unused_kwargs.items():
809
+ if logger is not None:
810
+ logger.info(f"{k:50s}: {repr(v)}")
811
+
812
+ model_file = os.path.join(pretrained_model_path, SAFETENSORS_WEIGHTS_NAME)
813
+ if not os.path.isfile(model_file):
814
+ raise RuntimeError(f"{model_file} does not exist")
815
+
816
+ state_dict = safetensors.torch.load_file(model_file, device="cpu")
817
+ missing, unexpected = model.load_state_dict(state_dict, strict=False)
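+ # strict=False: the 2D checkpoint has no motion-module or fuser weights, so those keys are expected to be missing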
818
+ if logger is not None:
819
+ logger.info(f"Missing keys: {len(missing)}; Unexpected keys: {len(unexpected)};")
820
+ assert len(unexpected) == 0
821
+
822
+ params = [p.numel() if "motion_modules." in n else 0 for n, p in model.named_parameters()]
823
+ if logger is not None:
824
+ logger.info(f"Motion module parameters: {sum(params) / 1e6} M")
825
+
826
+ return model
827
+
828
+
829
+ class UNet3DConditionModelCameraCond(UNet3DConditionModel):
830
+ _supports_gradient_checkpointing = True
831
+
832
+ @classmethod
833
+ def extract_init_dict(cls, config_dict, **kwargs):
834
+ # Skip keys that were not present in the original config, so default __init__ values were used
835
+ used_defaults = config_dict.get("_use_default_values", [])
836
+ config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}
837
+
838
+ # 0. Copy origin config dict
839
+ original_dict = dict(config_dict.items())
840
+
841
+ # 1. Retrieve expected config attributes from __init__ signature
842
+ expected_keys = cls._get_init_keys(cls)
843
+ expected_keys.remove("self")
844
+ super_expected_keys = cls._get_init_keys(UNet3DConditionModel)
845
+ super_expected_keys.remove("self")
846
+ # remove general kwargs if present in dict
847
+ if "kwargs" in expected_keys:
848
+ expected_keys.remove("kwargs")
849
+ if "kwargs" in super_expected_keys:
850
+ super_expected_keys.remove("kwargs")
851
+ # remove flax internal keys
852
+ if hasattr(cls, "_flax_internal_args"):
853
+ for arg in cls._flax_internal_args:
854
+ expected_keys.remove(arg)
855
+ expected_keys = expected_keys.union(super_expected_keys)
856
+
857
+ # remove private attributes
858
+ config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}
859
+
860
+ # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
861
+ init_dict = {}
862
+ for key in expected_keys:
863
+ # if config param is passed to kwarg and is present in config dict
864
+ # it should overwrite existing config dict key
865
+ if key in kwargs and key in config_dict:
866
+ config_dict[key] = kwargs.pop(key)
867
+
868
+ if key in kwargs:
869
+ # overwrite key
870
+ init_dict[key] = kwargs.pop(key)
871
+ elif key in config_dict:
872
+ # use value from config dict
873
+ init_dict[key] = config_dict.pop(key)
874
+
875
+ # 4. Give nice warning if unexpected values have been passed
876
+ if len(config_dict) > 0:
877
+ print(
878
+ f"The config attributes {config_dict} were passed to {cls.__name__}, "
879
+ "but are not expected and will be ignored. Please verify your "
880
+ f"{cls.config_name} configuration file."
881
+ )
882
+
883
+ # 6. Define unused keyword arguments
884
+ unused_kwargs = {**config_dict, **kwargs}
885
+
886
+ # 7. Define "hidden" config parameters that were saved for compatible classes
887
+ hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}
888
+
889
+ return init_dict, unused_kwargs, hidden_config_dict
890
+
891
+ def __init__(self,
892
+ decoder_add_cameracond=True,
893
+ **kwargs):
894
+ super(UNet3DConditionModelCameraCond, self).__init__(**kwargs)
895
+ self.decoder_add_cameracond = decoder_add_cameracond
896
+
897
+ def set_all_attn_processor(self,
898
+ add_spatial=False,
899
+ spatial_attn_names='attn1',
900
+ add_temporal=False,
901
+ add_spatial_lora=True,
902
+ add_motion_lora=False,
903
+ temporal_attn_names='0',
904
+ camera_feature_dimensions=[320, 640, 1280, 1280],
905
+ lora_kwargs={},
906
+ motion_lora_kwargs={},
907
+ **attention_processor_kwargs):
908
+ lora_rank = lora_kwargs.pop('lora_rank')
909
+ motion_lora_rank = motion_lora_kwargs.pop('lora_rank')
910
+ spatial_attn_procs = {}
911
+ if add_spatial:
912
+ set_processor_names = spatial_attn_names.split(',')
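+ # comma-separated attention names (default "attn1") selecting which spatial attention layers receive the camera adaptor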
913
+ for name in self.attn_processors.keys():
914
+ attention_name = name.split('.')[-2]
915
+ cross_attention_dim = None if attention_name == 'attn1' else self.config.cross_attention_dim
916
+ if name.startswith("mid_block"):
917
+ hidden_size = self.config.block_out_channels[-1]
918
+ block_id = -1
919
+ add_camera_adaptor = attention_name in set_processor_names
920
+ camera_feature_dim = camera_feature_dimensions[block_id] if add_camera_adaptor else None
921
+ elif name.startswith("up_blocks"):
922
+ block_id = int(name[len("up_blocks.")])
923
+ hidden_size = list(reversed(self.config.block_out_channels))[block_id]
924
+ add_camera_adaptor = attention_name in set_processor_names
925
+ camera_feature_dim = list(reversed(camera_feature_dimensions))[block_id] if add_camera_adaptor else None
926
+ else:
927
+ assert name.startswith("down_blocks")
928
+ block_id = int(name[len("down_blocks.")])
929
+ hidden_size = self.config.block_out_channels[block_id]
930
+ add_camera_adaptor = attention_name in set_processor_names
931
+ camera_feature_dim = camera_feature_dimensions[block_id] if add_camera_adaptor else None
932
+ if add_camera_adaptor and add_spatial_lora:
933
+ spatial_attn_procs[name] = LORACameraAdaptorAttnProcessor(hidden_size=hidden_size,
934
+ camera_feature_dim=camera_feature_dim,
935
+ cross_attention_dim=cross_attention_dim,
936
+ rank=lora_rank if lora_rank > 16 else hidden_size // lora_rank,
937
+ **attention_processor_kwargs,
938
+ **lora_kwargs)
939
+ elif add_camera_adaptor:
940
+ spatial_attn_procs[name] = CameraAdaptorAttnProcessor(hidden_size=hidden_size,
941
+ camera_feature_dim=camera_feature_dim,
942
+ cross_attention_dim=cross_attention_dim,
943
+ **attention_processor_kwargs)
944
+ elif add_spatial_lora:
945
+ spatial_attn_procs[name] = CustomizedLoRAAttnProcessor(hidden_size=hidden_size,
946
+ cross_attention_dim=cross_attention_dim,
947
+ rank=lora_rank if lora_rank > 16 else hidden_size // lora_rank)
948
+ else:
949
+ spatial_attn_procs[name] = CustomizedAttnProcessor()
950
+ elif (not add_spatial) and add_spatial_lora:
951
+ for name in self.attn_processors.keys():
952
+ cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim
953
+ if name.startswith("mid_block"):
954
+ hidden_size = self.config.block_out_channels[-1]
955
+ elif name.startswith("up_blocks"):
956
+ block_id = int(name[len("up_blocks.")])
957
+ hidden_size = list(reversed(self.config.block_out_channels))[block_id]
958
+ elif name.startswith("down_blocks"):
959
+ block_id = int(name[len("down_blocks.")])
960
+ hidden_size = self.config.block_out_channels[block_id]
961
+
962
+ spatial_attn_procs[name] = CustomizedLoRAAttnProcessor(
963
+ hidden_size=hidden_size,
964
+ cross_attention_dim=cross_attention_dim,
965
+ rank=lora_rank if lora_rank > 16 else hidden_size // lora_rank,
966
+ )
967
+ else:
968
+ for name in self.attn_processors.keys():
969
+ spatial_attn_procs[name] = CustomizedAttnProcessor()
970
+ self.set_attn_processor(spatial_attn_procs)
971
+
972
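+ # Repeat the same processor selection for the motion-module (temporal) attention layers.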
+ mm_attn_procs = {}
973
+ if add_temporal:
974
+ set_processor_names = temporal_attn_names.split(',')
975
+ cross_attention_dim = None
976
+ for name in self.mm_attn_processors.keys():
977
+ attention_name = name.split('.')[-2]
978
+ if name.startswith("mid_block"):
979
+ hidden_size = self.config.block_out_channels[-1]
980
+ block_id = -1
981
+ add_camera_adaptor = attention_name in set_processor_names
982
+ camera_feature_dim = camera_feature_dimensions[block_id] if add_camera_adaptor else None
983
+ elif name.startswith("up_blocks"):
984
+ block_id = int(name[len("up_blocks.")])
985
+ hidden_size = list(reversed(self.config.block_out_channels))[block_id]
986
+ add_camera_adaptor = (attention_name in set_processor_names) and self.decoder_add_cameracond
987
+ camera_feature_dim = list(reversed(camera_feature_dimensions))[block_id] if add_camera_adaptor else None
988
+ elif name.startswith("down_blocks"):
989
+ block_id = int(name[len("down_blocks.")])
990
+ hidden_size = self.config.block_out_channels[block_id]
991
+ add_camera_adaptor = attention_name in set_processor_names
992
+ camera_feature_dim = camera_feature_dimensions[block_id] if add_camera_adaptor else None
993
+ if add_camera_adaptor and add_motion_lora:
994
+ mm_attn_procs[name] = LORACameraAdaptorAttnProcessor(hidden_size=hidden_size,
995
+ camera_feature_dim=camera_feature_dim,
996
+ cross_attention_dim=cross_attention_dim,
997
+ rank=motion_lora_rank if motion_lora_rank > 16 else hidden_size // motion_lora_rank,
998
+ **attention_processor_kwargs,
999
+ **motion_lora_kwargs)
1000
+ elif add_camera_adaptor:
1001
+ mm_attn_procs[name] = CameraAdaptorAttnProcessor(hidden_size=hidden_size,
1002
+ camera_feature_dim=camera_feature_dim,
1003
+ cross_attention_dim=cross_attention_dim,
1004
+ **attention_processor_kwargs)
1005
+ elif add_motion_lora:
1006
+ mm_attn_procs[name] = CustomizedLoRAAttnProcessor(hidden_size=hidden_size,
1007
+ cross_attention_dim=cross_attention_dim,
1008
+ rank=motion_lora_rank if motion_lora_rank > 16 else hidden_size // motion_lora_rank)
1009
+ else:
1010
+ mm_attn_procs[name] = CustomizedAttnProcessor()
1011
+ elif (not add_temporal) and add_motion_lora:
1012
+ for name in self.mm_attn_processors.keys():
1013
+ cross_attention_dim = None
1014
+ if name.startswith("mid_block"):
1015
+ hidden_size = self.config.block_out_channels[-1]
1016
+ elif name.startswith("up_blocks"):
1017
+ block_id = int(name[len("up_blocks.")])
1018
+ hidden_size = list(reversed(self.config.block_out_channels))[block_id]
1019
+ elif name.startswith("down_blocks"):
1020
+ block_id = int(name[len("down_blocks.")])
1021
+ hidden_size = self.config.block_out_channels[block_id]
1022
+
1023
+ mm_attn_procs[name] = CustomizedLoRAAttnProcessor(
1024
+ hidden_size=hidden_size,
1025
+ cross_attention_dim=cross_attention_dim,
1026
+ rank=motion_lora_rank if motion_lora_rank > 16 else hidden_size // motion_lora_rank,
1027
+ )
1028
+ else:
1029
+ for name in self.mm_attn_processors.keys():
1030
+ mm_attn_procs[name] = CustomizedAttnProcessor()
1031
+ self.set_mm_attn_processor(mm_attn_procs)
1032
+
1033
+ def forward(
1034
+ self,
1035
+ sample: torch.FloatTensor,
1036
+ timestep: Union[torch.Tensor, float, int],
1037
+ encoder_hidden_states: Union[torch.Tensor, List[torch.Tensor]],
1038
+ class_labels: Optional[torch.Tensor] = None,
1039
+ attention_mask: Optional[torch.Tensor] = None,
1040
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1041
+ camera_embedding_features: List[torch.Tensor] = None,
1042
+ return_dict: bool = True,
1043
+
1044
+ # support controlnet
1045
+ down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
1046
+ mid_block_additional_residual: Optional[torch.Tensor] = None,
1047
+
1048
+ # other features
1049
+ motion_module_alphas: Union[tuple, float] = 1.0,
1050
+ debug: bool = False,
1051
+ ) -> Union[UNet3DConditionOutput, Tuple]:
1052
+
1053
+ activations = {}
1054
+
1055
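+ # By default, samples have to be at least a multiple of the overall upsampling factor.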
+ default_overall_up_factor = 2 ** self.num_upsamplers
1056
+
1057
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
1058
+ forward_upsample_size = False
1059
+ upsample_size = None
1060
+
1061
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
1062
+ self.logger.info("Forward upsample size to force interpolation output size.")
1063
+ forward_upsample_size = True
1064
+
1065
+ # prepare attention_mask
1066
+ if attention_mask is not None:
1067
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
1068
+ attention_mask = attention_mask.unsqueeze(1)
1069
+
1070
+ # center input if necessary
1071
+ if self.config.center_input_sample:
1072
+ sample = 2 * sample - 1.0
1073
+
1074
+ # time
1075
+ timesteps = timestep
1076
+ if not torch.is_tensor(timesteps):
1077
+ # This would be a good case for the `match` statement (Python 3.10+)
1078
+ is_mps = sample.device.type == "mps"
1079
+ if isinstance(timestep, float):
1080
+ dtype = torch.float32 if is_mps else torch.float64
1081
+ else:
1082
+ dtype = torch.int32 if is_mps else torch.int64
1083
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
1084
+ elif len(timesteps.shape) == 0:
1085
+ timesteps = timesteps[None].to(sample.device)
1086
+
1087
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
1088
+ timesteps = timesteps.expand(sample.shape[0])
1089
+
1090
+ t_emb = self.time_proj(timesteps)
1091
+
1092
+ # timesteps does not contain any weights and will always return f32 tensors
1093
+ # but time_embedding might actually be running in fp16. so we need to cast here.
1094
+ # there might be better ways to encapsulate this.
1095
+ t_emb = t_emb.to(dtype=self.dtype)
1096
+ emb = self.time_embedding(t_emb)
1097
+
1098
+ if self.class_embedding is not None:
1099
+ if class_labels is None:
1100
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
1101
+
1102
+ if self.config.class_embed_type == "timestep":
1103
+ class_labels = self.time_proj(class_labels)
1104
+
1105
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
1106
+ emb = emb + class_emb
1107
+
1108
+ # extend encoder_hidden_states
1109
+ video_length = sample.shape[2]
1110
+ encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b f) n c", f=video_length)
1111
+
1112
+ # pre-process
1113
+ sample = self.conv_in(sample) # b c f h w
1114
+ activations["conv_in_out"] = sample
1115
+
1116
+ # to be fused
1117
+ if self.down_fusers[0] != None:
1118
+ # scale, shift = self.down_fusers[0](sample[:,:,0].contiguous(), emb_single).unsqueeze(2).chunk(2, dim=1)
1119
+ # sample[:,:,1:] = (1 + scale) * sample[:,:,1:].contiguous() + shift
1120
+ fused_sample = self.down_fusers[0](
1121
+ init_hidden_state=sample[:, :, :1].contiguous(),
1122
+ post_hidden_states=sample[:, :, 1:].contiguous(),
1123
+ temb=emb_single,
1124
+ )
1125
+ sample = torch.cat([sample[:, :, :1], fused_sample], dim=2)
1126
+
1127
+ activations["conv_in_fuse_out"] = sample
1128
+
1129
+ # down
1130
+ down_block_res_samples = (sample,)
1131
+
1132
+ # motion module alpha
1133
+ if isinstance(motion_module_alphas, float):
1134
+ motion_module_alphas = (motion_module_alphas,) * 5
1135
+
1136
+ for downsample_block, camera_embedding_feature, down_fuser, motion_module_alpha in zip(self.down_blocks,
1137
+ camera_embedding_features,
1138
+ self.down_fusers[1:],
1139
+ motion_module_alphas[:-1]):
1140
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
1141
+ sample, res_samples = downsample_block(
1142
+ hidden_states=sample,
1143
+ temb=emb,
1144
+ encoder_hidden_states=encoder_hidden_states,
1145
+ attention_mask=attention_mask,
1146
+ motion_module_alpha=motion_module_alpha,
1147
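+ # build a new dict here: dict.update() returns None and would silently drop the kwargs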
+ cross_attention_kwargs={**cross_attention_kwargs, "camera_feature": camera_embedding_feature}
1148
+ if cross_attention_kwargs is not None else {"camera_feature": camera_embedding_feature},
1149
+ motion_cross_attention_kwargs={"camera_feature": camera_embedding_feature}
1150
+ )
1151
+ else:
1152
+ sample, res_samples = downsample_block(
1153
+ hidden_states=sample,
1154
+ temb=emb,
1155
+ motion_module_alpha=motion_module_alpha,
1156
+ cross_attention_kwargs={**cross_attention_kwargs, "camera_feature": camera_embedding_feature}
1157
+ if cross_attention_kwargs is not None else {"camera_feature": camera_embedding_feature},
1158
+ motion_cross_attention_kwargs={"camera_feature": camera_embedding_feature}
1159
+ )
1160
+
1161
+ # to be fused
1162
+ for sample_idx, fuser in enumerate(down_fuser):
1163
+ if fuser != None:
1164
+ fused_sample = fuser(
1165
+ init_hidden_state=res_samples[sample_idx][:, :, :1].contiguous(),
1166
+ post_hidden_states=res_samples[sample_idx][:, :, 1:].contiguous(),
1167
+ temb=emb_single,
1168
+ )
1169
+ res_samples = list(res_samples)
1170
+ res_samples[sample_idx] = torch.cat([res_samples[sample_idx][:, :, :1], fused_sample], dim=2)
1171
+ res_samples = tuple(res_samples)
1172
+
1173
+ down_block_res_samples += res_samples
1174
+
1175
+ # support controlnet
1176
+ if down_block_additional_residuals is not None:
1177
+ new_down_block_res_samples = ()
1178
+
1179
+ for down_block_res_sample, down_block_additional_residual in zip(
1180
+ down_block_res_samples, down_block_additional_residuals
1181
+ ):
1182
+ if len(down_block_additional_residual.shape) == 4:
1183
+ # b c h w
1184
+ # if input single condition, apply it to all frames
1185
+ down_block_additional_residual = down_block_additional_residual.unsqueeze(2)
1186
+ # broadcasting over the frame dimension handles this, so the repeat below is unnecessary
1187
+ # down_block_additional_residual = repeat(down_block_additional_residual, "b c f h w -> b c (f n) h w", n=video_length)
1188
+
1189
+ down_block_res_sample = down_block_res_sample + down_block_additional_residual
1190
+ new_down_block_res_samples += (down_block_res_sample,)
1191
+
1192
+ down_block_res_samples = new_down_block_res_samples
1193
+
1194
+ # mid
1195
+ sample = self.mid_block(
1196
+ sample,
1197
+ emb,
1198
+ encoder_hidden_states=encoder_hidden_states,
1199
+ attention_mask=attention_mask,
1200
+ motion_module_alpha=motion_module_alphas[-1],
1201
+ cross_attention_kwargs={**cross_attention_kwargs, "camera_feature": camera_embedding_features[-1]}
1202
+ if cross_attention_kwargs is not None else {"camera_feature": camera_embedding_features[-1]},
1203
+ motion_cross_attention_kwargs={"camera_feature": camera_embedding_features[-1]}
1204
+ )
1205
+
1206
+ # mid block fuser
1207
+ if self.mid_fuser != None:
1208
+ fused_sample = self.mid_fuser(
1209
+ init_hidden_state=sample[:, :, :1],
1210
+ post_hidden_states=sample[:, :, 1:],
1211
+ temb=emb_single,
1212
+ )
1213
+ sample = torch.cat([sample[:, :, :1], fused_sample], dim=2)
1214
+
1215
+ # support controlnet
1216
+ if mid_block_additional_residual is not None:
1217
+ if len(mid_block_additional_residual.shape) == 4:
1218
+ mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2)
1219
+ # broadcasting over the frame dimension handles this, so the repeat below is unnecessary
1220
+ # mid_block_additional_residual = repeat(mid_block_additional_residual, "b c f h w -> b c (f n) h w", n=video_length)
1221
+
1222
+ sample = sample + mid_block_additional_residual
1223
+
1224
+ # up
1225
+ for i, (upsample_block, motion_module_alpha) in enumerate(zip(self.up_blocks, motion_module_alphas[:-1][::-1])):
1226
+ is_final_block = i == len(self.up_blocks) - 1
1227
+ camera_embedding_feature = camera_embedding_features[-(i+1)] if self.decoder_add_cameracond else None
1228
+
1229
+ res_samples = down_block_res_samples[-len(upsample_block.resnets):]
1230
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
1231
+
1232
+ # if we have not reached the final block and need to forward the
1233
+ # upsample size, we do it here
1234
+ if not is_final_block and forward_upsample_size:
1235
+ upsample_size = down_block_res_samples[-1].shape[2:]
1236
+
1237
+ if self.decoder_add_cameracond:
1238
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1239
+ sample = upsample_block(
1240
+ hidden_states=sample,
1241
+ temb=emb,
1242
+ res_hidden_states_tuple=res_samples,
1243
+ encoder_hidden_states=encoder_hidden_states,
1244
+ upsample_size=upsample_size,
1245
+ attention_mask=attention_mask,
1246
+ motion_module_alpha=motion_module_alpha,
1247
+ cross_attention_kwargs={**cross_attention_kwargs, "camera_feature": camera_embedding_feature}
1248
+ if cross_attention_kwargs is not None else {"camera_feature": camera_embedding_feature},
1249
+ motion_cross_attention_kwargs={"camera_feature": camera_embedding_feature}
1250
+ )
1251
+ else:
1252
+ sample = upsample_block(
1253
+ hidden_states=sample,
1254
+ temb=emb,
1255
+ res_hidden_states_tuple=res_samples,
1256
+ upsample_size=upsample_size,
1257
+ motion_module_alpha=motion_module_alpha,
1258
+ cross_attention_kwargs={**cross_attention_kwargs, "camera_feature": camera_embedding_feature}
1259
+ if cross_attention_kwargs is not None else {"camera_feature": camera_embedding_feature},
1260
+ motion_cross_attention_kwargs={"camera_feature": camera_embedding_feature}
1261
+ )
1262
+ else:
1263
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
1264
+ sample = upsample_block(
1265
+ hidden_states=sample,
1266
+ temb=emb,
1267
+ res_hidden_states_tuple=res_samples,
1268
+ encoder_hidden_states=encoder_hidden_states,
1269
+ upsample_size=upsample_size,
1270
+ attention_mask=attention_mask,
1271
+ motion_module_alpha=motion_module_alpha,
1272
+ cross_attention_kwargs=cross_attention_kwargs,
1273
+ )
1274
+ else:
1275
+ sample = upsample_block(
1276
+ hidden_states=sample,
1277
+ temb=emb,
1278
+ res_hidden_states_tuple=res_samples,
1279
+ upsample_size=upsample_size,
1280
+ motion_module_alpha=motion_module_alpha,
1281
+ cross_attention_kwargs=cross_attention_kwargs
1282
+ )
1283
+
1284
+ activations["upblocks_out"] = sample
1285
+
1286
+ # post-process
1287
+ # frame-wise normalization
1288
+ sample = rearrange(sample, "b c f h w -> (b f) c h w")
1289
+ sample = self.conv_norm_out(sample)
1290
+ sample = rearrange(sample, "(b f) c h w -> b c f h w", f=video_length)
1291
+
1292
+ sample = self.conv_act(sample)
1293
+ sample = self.conv_out(sample)
1294
+
1295
+ if not return_dict:
1296
+ return (sample,)
1297
+ elif debug:
1298
+ return UNet3DConditionOutput(sample=sample), activations
1299
+ else:
1300
+ return UNet3DConditionOutput(sample=sample)
genphoto/models/unet_blocks.py CHANGED
@@ -1,3 +1,818 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:767e2392b19861d964d37159b591b9d489abc9a30332fb1a337694d7f3a94f28
3
- size 34808
1
+ # Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
2
+
3
+ import torch
4
+ from torch import nn
5
+ from einops import rearrange, repeat
6
+ from diffusers.models.resnet import Downsample2D, Upsample2D, ResnetBlock2D
7
+ from diffusers.models.transformer_2d import Transformer2DModel
8
+
9
+ from genphoto.models.motion_module import get_motion_module
10
+
11
+
12
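+ # Factory helper: returns the requested 3D down block (plain or cross-attention), optionally with a motion module.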
+ def get_down_block(
13
+ down_block_type,
14
+ num_layers,
15
+ in_channels,
16
+ out_channels,
17
+ temb_channels,
18
+ add_downsample,
19
+ resnet_eps,
20
+ resnet_act_fn,
21
+ attn_num_head_channels,
22
+ resnet_groups=None,
23
+ cross_attention_dim=None,
24
+ downsample_padding=None,
25
+ dual_cross_attention=False,
26
+ use_linear_projection=False,
27
+ only_cross_attention=False,
28
+ upcast_attention=False,
29
+ resnet_time_scale_shift="default",
30
+ use_motion_module=None,
31
+ motion_module_type=None,
32
+ motion_module_kwargs=None,
33
+ ):
34
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
35
+ if down_block_type == "DownBlock3D":
36
+ return DownBlock3D(
37
+ num_layers=num_layers,
38
+ in_channels=in_channels,
39
+ out_channels=out_channels,
40
+ temb_channels=temb_channels,
41
+ add_downsample=add_downsample,
42
+ resnet_eps=resnet_eps,
43
+ resnet_act_fn=resnet_act_fn,
44
+ resnet_groups=resnet_groups,
45
+ downsample_padding=downsample_padding,
46
+ resnet_time_scale_shift=resnet_time_scale_shift,
47
+ use_motion_module=use_motion_module,
48
+ motion_module_type=motion_module_type,
49
+ motion_module_kwargs=motion_module_kwargs,
50
+ )
51
+ elif down_block_type == "CrossAttnDownBlock3D":
52
+ if cross_attention_dim is None:
53
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
54
+ return CrossAttnDownBlock3D(
55
+ num_layers=num_layers,
56
+ in_channels=in_channels,
57
+ out_channels=out_channels,
58
+ temb_channels=temb_channels,
59
+ add_downsample=add_downsample,
60
+ resnet_eps=resnet_eps,
61
+ resnet_act_fn=resnet_act_fn,
62
+ resnet_groups=resnet_groups,
63
+ downsample_padding=downsample_padding,
64
+ cross_attention_dim=cross_attention_dim,
65
+ attn_num_head_channels=attn_num_head_channels,
66
+ dual_cross_attention=dual_cross_attention,
67
+ use_linear_projection=use_linear_projection,
68
+ only_cross_attention=only_cross_attention,
69
+ upcast_attention=upcast_attention,
70
+ resnet_time_scale_shift=resnet_time_scale_shift,
71
+ use_motion_module=use_motion_module,
72
+ motion_module_type=motion_module_type,
73
+ motion_module_kwargs=motion_module_kwargs,
74
+ )
75
+ raise ValueError(f"{down_block_type} does not exist.")
76
+
77
+
78
+ def get_up_block(
79
+ up_block_type,
80
+ num_layers,
81
+ in_channels,
82
+ out_channels,
83
+ prev_output_channel,
84
+ temb_channels,
85
+ add_upsample,
86
+ resnet_eps,
87
+ resnet_act_fn,
88
+ attn_num_head_channels,
89
+ resnet_groups=None,
90
+ cross_attention_dim=None,
91
+ dual_cross_attention=False,
92
+ use_linear_projection=False,
93
+ only_cross_attention=False,
94
+ upcast_attention=False,
95
+ resnet_time_scale_shift="default",
96
+ use_motion_module=None,
97
+ motion_module_type=None,
98
+ motion_module_kwargs=None,
99
+ ):
100
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
101
+ if up_block_type == "UpBlock3D":
102
+ return UpBlock3D(
103
+ num_layers=num_layers,
104
+ in_channels=in_channels,
105
+ out_channels=out_channels,
106
+ prev_output_channel=prev_output_channel,
107
+ temb_channels=temb_channels,
108
+ add_upsample=add_upsample,
109
+ resnet_eps=resnet_eps,
110
+ resnet_act_fn=resnet_act_fn,
111
+ resnet_groups=resnet_groups,
112
+ resnet_time_scale_shift=resnet_time_scale_shift,
113
+ use_motion_module=use_motion_module,
114
+ motion_module_type=motion_module_type,
115
+ motion_module_kwargs=motion_module_kwargs,
116
+ )
117
+ elif up_block_type == "CrossAttnUpBlock3D":
118
+ if cross_attention_dim is None:
119
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
120
+ return CrossAttnUpBlock3D(
121
+ num_layers=num_layers,
122
+ in_channels=in_channels,
123
+ out_channels=out_channels,
124
+ prev_output_channel=prev_output_channel,
125
+ temb_channels=temb_channels,
126
+ add_upsample=add_upsample,
127
+ resnet_eps=resnet_eps,
128
+ resnet_act_fn=resnet_act_fn,
129
+ resnet_groups=resnet_groups,
130
+ cross_attention_dim=cross_attention_dim,
131
+ attn_num_head_channels=attn_num_head_channels,
132
+ dual_cross_attention=dual_cross_attention,
133
+ use_linear_projection=use_linear_projection,
134
+ only_cross_attention=only_cross_attention,
135
+ upcast_attention=upcast_attention,
136
+ resnet_time_scale_shift=resnet_time_scale_shift,
137
+ use_motion_module=use_motion_module,
138
+ motion_module_type=motion_module_type,
139
+ motion_module_kwargs=motion_module_kwargs,
140
+ )
141
+ raise ValueError(f"{up_block_type} does not exist.")
142
+
143
+
144
+ class UNetMidBlock3DCrossAttn(nn.Module):
145
+ def __init__(
146
+ self,
147
+ in_channels: int,
148
+ temb_channels: int,
149
+ dropout: float = 0.0,
150
+ num_layers: int = 1,
151
+ resnet_eps: float = 1e-6,
152
+ resnet_time_scale_shift: str = "default",
153
+ resnet_act_fn: str = "swish",
154
+ resnet_groups: int = 32,
155
+ resnet_pre_norm: bool = True,
156
+ attn_num_head_channels=1,
157
+ output_scale_factor=1.0,
158
+ cross_attention_dim=1280,
159
+ dual_cross_attention=False,
160
+ use_linear_projection=False,
161
+ upcast_attention=False,
162
+
163
+ use_motion_module=None,
164
+ motion_module_type=None,
165
+ motion_module_kwargs=None,
166
+ ):
167
+ super().__init__()
168
+
169
+ self.has_cross_attention = True
170
+ self.attn_num_head_channels = attn_num_head_channels
171
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
172
+
173
+ # there is always at least one resnet
174
+ resnets = [
175
+ ResnetBlock2D(
176
+ in_channels=in_channels,
177
+ out_channels=in_channels,
178
+ temb_channels=temb_channels,
179
+ eps=resnet_eps,
180
+ groups=resnet_groups,
181
+ dropout=dropout,
182
+ time_embedding_norm=resnet_time_scale_shift,
183
+ non_linearity=resnet_act_fn,
184
+ output_scale_factor=output_scale_factor,
185
+ pre_norm=resnet_pre_norm,
186
+ )
187
+ ]
188
+ attentions = []
189
+ motion_modules = []
190
+
191
+ for _ in range(num_layers):
192
+ if dual_cross_attention: raise NotImplementedError
193
+ attentions.append(
194
+ Transformer2DModel(
195
+ attn_num_head_channels,
196
+ in_channels // attn_num_head_channels,
197
+ in_channels=in_channels,
198
+ num_layers=1,
199
+ cross_attention_dim=cross_attention_dim,
200
+ norm_num_groups=resnet_groups,
201
+ use_linear_projection=use_linear_projection,
202
+ upcast_attention=upcast_attention,
203
+ )
204
+ )
205
+ motion_modules.append(
206
+ get_motion_module(
207
+ in_channels=in_channels,
208
+ motion_module_type=motion_module_type,
209
+ motion_module_kwargs=motion_module_kwargs,
210
+ ) if use_motion_module else None
211
+ )
212
+ resnets.append(
213
+ ResnetBlock2D(
214
+ in_channels=in_channels,
215
+ out_channels=in_channels,
216
+ temb_channels=temb_channels,
217
+ eps=resnet_eps,
218
+ groups=resnet_groups,
219
+ dropout=dropout,
220
+ time_embedding_norm=resnet_time_scale_shift,
221
+ non_linearity=resnet_act_fn,
222
+ output_scale_factor=output_scale_factor,
223
+ pre_norm=resnet_pre_norm,
224
+ )
225
+ )
226
+
227
+ self.attentions = nn.ModuleList(attentions)
228
+ self.resnets = nn.ModuleList(resnets)
229
+ self.motion_modules = nn.ModuleList(motion_modules) if use_motion_module else motion_modules
230
+
231
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None,
232
+ motion_module_alpha=1., cross_attention_kwargs=None, motion_cross_attention_kwargs=None):
233
+ video_length = hidden_states.shape[2]
234
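+ # repeat the time embedding per frame so the frame-wise 2D resnets can consume it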
+ temb_repeated = repeat(temb, "b c -> (b f) c", f=video_length)
235
+
236
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
237
+ hidden_states = self.resnets[0](hidden_states, temb_repeated)
238
+ hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length)
239
+
240
+ lora_scale = getattr(self, "lora_scale", None)
241
+ if lora_scale != None:
242
+ cross_attention_kwargs = {"scale": lora_scale}
243
+ motion_lora_scale = getattr(self, "motion_lora_scale", None)
244
+ if motion_lora_scale != None:
245
+ if motion_cross_attention_kwargs is None:
246
+ motion_cross_attention_kwargs = {"scale": motion_lora_scale}
247
+ else:
248
+ motion_cross_attention_kwargs.update({"scale": motion_lora_scale})
249
+
250
+ for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules):
251
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
252
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states,
253
+ cross_attention_kwargs=cross_attention_kwargs).sample
254
+ hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length)
255
+
256
+ # motion module
257
+ if motion_module is not None:
258
+ # hidden_states = motion_module_alpha * motion_module(hidden_states, temb=temb, encoder_hidden_states=encoder_hidden_states) + hidden_states
259
+ hidden_states = motion_module(hidden_states, temb=temb, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=motion_cross_attention_kwargs)
260
+
261
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
262
+ hidden_states = resnet(hidden_states, temb_repeated)
263
+ hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length)
264
+
265
+ return hidden_states
266
+
267
+
268
+ class CrossAttnDownBlock3D(nn.Module):
269
+ def __init__(
270
+ self,
271
+ in_channels: int,
272
+ out_channels: int,
273
+ temb_channels: int,
274
+ dropout: float = 0.0,
275
+ num_layers: int = 1,
276
+ resnet_eps: float = 1e-6,
277
+ resnet_time_scale_shift: str = "default",
278
+ resnet_act_fn: str = "swish",
279
+ resnet_groups: int = 32,
280
+ resnet_pre_norm: bool = True,
281
+ attn_num_head_channels=1,
282
+ cross_attention_dim=1280,
283
+ output_scale_factor=1.0,
284
+ downsample_padding=1,
285
+ add_downsample=True,
286
+ dual_cross_attention=False,
287
+ use_linear_projection=False,
288
+ only_cross_attention=False,
289
+ upcast_attention=False,
290
+
291
+ use_motion_module=None,
292
+ motion_module_type=None,
293
+ motion_module_kwargs=None,
294
+ ):
295
+ super().__init__()
296
+ resnets = []
297
+ attentions = []
298
+ motion_modules = []
299
+
300
+ self.has_cross_attention = True
301
+ self.attn_num_head_channels = attn_num_head_channels
302
+
303
+ for i in range(num_layers):
304
+ in_channels = in_channels if i == 0 else out_channels
305
+ resnets.append(
306
+ ResnetBlock2D(
307
+ in_channels=in_channels,
308
+ out_channels=out_channels,
309
+ temb_channels=temb_channels,
310
+ eps=resnet_eps,
311
+ groups=resnet_groups,
312
+ dropout=dropout,
313
+ time_embedding_norm=resnet_time_scale_shift,
314
+ non_linearity=resnet_act_fn,
315
+ output_scale_factor=output_scale_factor,
316
+ pre_norm=resnet_pre_norm,
317
+ )
318
+ )
319
+
320
+ if dual_cross_attention:
321
+ raise NotImplementedError
322
+ attentions.append(
323
+ Transformer2DModel(
324
+ attn_num_head_channels,
325
+ out_channels // attn_num_head_channels,
326
+ in_channels=out_channels,
327
+ num_layers=1,
328
+ cross_attention_dim=cross_attention_dim,
329
+ norm_num_groups=resnet_groups,
330
+ use_linear_projection=use_linear_projection,
331
+ only_cross_attention=only_cross_attention,
332
+ upcast_attention=upcast_attention,
333
+ )
334
+ )
335
+ motion_modules.append(
336
+ get_motion_module(
337
+ in_channels=out_channels,
338
+ motion_module_type=motion_module_type,
339
+ motion_module_kwargs=motion_module_kwargs,
340
+ ) if use_motion_module else None
341
+ )
342
+
343
+ self.attentions = nn.ModuleList(attentions)
344
+ self.resnets = nn.ModuleList(resnets)
345
+ self.motion_modules = nn.ModuleList(motion_modules) if use_motion_module else motion_modules
346
+
347
+ if add_downsample:
348
+ self.downsamplers = nn.ModuleList(
349
+ [
350
+ Downsample2D(
351
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
352
+ )
353
+ ]
354
+ )
355
+ else:
356
+ self.downsamplers = None
357
+
358
+ self.gradient_checkpointing = False
359
+
360
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None,
361
+ motion_module_alpha=1., cross_attention_kwargs={}, motion_cross_attention_kwargs={}):
362
+ video_length = hidden_states.shape[2]
363
+ temb_repeated = repeat(temb, "b c -> (b f) c", f=video_length)
364
+
365
+ output_states = ()
366
+
367
+ lora_scale = getattr(self, "lora_scale", None)
368
+ if lora_scale != None:
369
+ cross_attention_kwargs["scale"] = lora_scale
370
+ motion_lora_scale = getattr(self, "motion_lora_scale", None)
371
+ if motion_lora_scale != None:
372
+ if motion_cross_attention_kwargs is None:
373
+ motion_cross_attention_kwargs = {"scale": motion_lora_scale}
374
+ else:
375
+ motion_cross_attention_kwargs.update({"scale": motion_lora_scale})
376
+
377
+ for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
378
+ if self.training and self.gradient_checkpointing:
379
+ raise NotImplementedError
380
+
381
+ def create_custom_forward(module, return_dict=None):
382
+ def custom_forward(*inputs):
383
+ if return_dict is not None:
384
+ return module(*inputs, return_dict=return_dict)
385
+ else:
386
+ return module(*inputs)
387
+
388
+ return custom_forward
389
+
390
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
391
+ hidden_states = torch.utils.checkpoint.checkpoint(
392
+ create_custom_forward(attn, return_dict=False),
393
+ hidden_states,
394
+ encoder_hidden_states,
395
+ )[0]
396
+ if motion_module is not None:
397
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module),
398
+ hidden_states.requires_grad_(), temb,
399
+ encoder_hidden_states)
400
+
401
+ else:
402
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
403
+ hidden_states = resnet(hidden_states, temb_repeated)
404
+ hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length)
405
+
406
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
407
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states,
408
+ cross_attention_kwargs=cross_attention_kwargs).sample
409
+ hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length)
410
+
411
+ # motion module
412
+ if motion_module is not None:
413
+ # hidden_states = motion_module_alpha * motion_module(hidden_states, temb=temb, encoder_hidden_states=encoder_hidden_states) + hidden_states
414
+ hidden_states = motion_module(hidden_states, temb=temb, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=motion_cross_attention_kwargs)
415
+
416
+ output_states += (hidden_states,)
417
+
418
+ if self.downsamplers is not None:
419
+ for downsampler in self.downsamplers:
420
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
421
+ hidden_states = downsampler(hidden_states)
422
+ hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length)
423
+
424
+ output_states += (hidden_states,)
425
+
426
+ return hidden_states, output_states
427
+
428
+
429
+ class DownBlock3D(nn.Module):
430
+ def __init__(
431
+ self,
432
+ in_channels: int,
433
+ out_channels: int,
434
+ temb_channels: int,
435
+ dropout: float = 0.0,
436
+ num_layers: int = 1,
437
+ resnet_eps: float = 1e-6,
438
+ resnet_time_scale_shift: str = "default",
439
+ resnet_act_fn: str = "swish",
440
+ resnet_groups: int = 32,
441
+ resnet_pre_norm: bool = True,
442
+ output_scale_factor=1.0,
443
+ add_downsample=True,
444
+ downsample_padding=1,
445
+
446
+ use_motion_module=None,
447
+ motion_module_type=None,
448
+ motion_module_kwargs=None,
449
+ ):
450
+ super().__init__()
451
+ resnets = []
452
+ motion_modules = []
453
+
454
+ for i in range(num_layers):
455
+ in_channels = in_channels if i == 0 else out_channels
456
+ resnets.append(
457
+ ResnetBlock2D(
458
+ in_channels=in_channels,
459
+ out_channels=out_channels,
460
+ temb_channels=temb_channels,
461
+ eps=resnet_eps,
462
+ groups=resnet_groups,
463
+ dropout=dropout,
464
+ time_embedding_norm=resnet_time_scale_shift,
465
+ non_linearity=resnet_act_fn,
466
+ output_scale_factor=output_scale_factor,
467
+ pre_norm=resnet_pre_norm,
468
+ )
469
+ )
470
+ motion_modules.append(
471
+ get_motion_module(
472
+ in_channels=out_channels,
473
+ motion_module_type=motion_module_type,
474
+ motion_module_kwargs=motion_module_kwargs,
475
+ ) if use_motion_module else None
476
+ )
477
+
478
+ self.resnets = nn.ModuleList(resnets)
479
+ self.motion_modules = nn.ModuleList(motion_modules) if use_motion_module else motion_modules
480
+
481
+ if add_downsample:
482
+ self.downsamplers = nn.ModuleList(
483
+ [
484
+ Downsample2D(
485
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
486
+ )
487
+ ]
488
+ )
489
+ else:
490
+ self.downsamplers = None
491
+
492
+ self.gradient_checkpointing = False
493
+
494
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, motion_module_alpha=1.,
495
+ motion_cross_attention_kwargs={}, **kwargs):
496
+ video_length = hidden_states.shape[2]
497
+ temb_repeated = repeat(temb, "b c -> (b f) c", f=video_length)
498
+ output_states = ()
499
+ motion_lora_scale = getattr(self, "motion_lora_scale", None)
500
+ if motion_lora_scale != None:
501
+ if motion_cross_attention_kwargs is None:
502
+ motion_cross_attention_kwargs = {"scale": motion_lora_scale}
503
+ else:
504
+ motion_cross_attention_kwargs.update({"scale": motion_lora_scale})
505
+
506
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
507
+ if self.training and self.gradient_checkpointing:
508
+ raise NotImplementedError
509
+
510
+ def create_custom_forward(module):
511
+ def custom_forward(*inputs):
512
+ return module(*inputs)
513
+
514
+ return custom_forward
515
+
516
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
517
+ if motion_module is not None:
518
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module),
519
+ hidden_states.requires_grad_(), temb,
520
+ encoder_hidden_states)
521
+ else:
522
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
523
+ hidden_states = resnet(hidden_states, temb_repeated)
524
+ hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length)
525
+
526
+ # motion module
527
+ if motion_module is not None:
528
+ hidden_states = motion_module(hidden_states, temb=temb, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=motion_cross_attention_kwargs)
529
+
530
+ output_states += (hidden_states,)
531
+
532
+ if self.downsamplers is not None:
533
+ for downsampler in self.downsamplers:
534
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
535
+ hidden_states = downsampler(hidden_states)
536
+ hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length)
537
+
538
+ output_states += (hidden_states,)
539
+
540
+ return hidden_states, output_states
541
+
542
+
543
+ class CrossAttnUpBlock3D(nn.Module):
544
+ def __init__(
545
+ self,
546
+ in_channels: int,
547
+ out_channels: int,
548
+ prev_output_channel: int,
549
+ temb_channels: int,
550
+ dropout: float = 0.0,
551
+ num_layers: int = 1,
552
+ resnet_eps: float = 1e-6,
553
+ resnet_time_scale_shift: str = "default",
554
+ resnet_act_fn: str = "swish",
555
+ resnet_groups: int = 32,
556
+ resnet_pre_norm: bool = True,
557
+ attn_num_head_channels=1,
558
+ cross_attention_dim=1280,
559
+ output_scale_factor=1.0,
560
+ add_upsample=True,
561
+ dual_cross_attention=False,
562
+ use_linear_projection=False,
563
+ only_cross_attention=False,
564
+ upcast_attention=False,
565
+
566
+ use_motion_module=None,
567
+ motion_module_type=None,
568
+ motion_module_kwargs=None,
569
+ ):
570
+ super().__init__()
571
+ resnets = []
572
+ attentions = []
573
+ motion_modules = []
574
+
575
+ self.has_cross_attention = True
576
+ self.attn_num_head_channels = attn_num_head_channels
577
+
578
+ for i in range(num_layers):
579
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
580
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
581
+
582
+ resnets.append(
583
+ ResnetBlock2D(
584
+ in_channels=resnet_in_channels + res_skip_channels,
585
+ out_channels=out_channels,
586
+ temb_channels=temb_channels,
587
+ eps=resnet_eps,
588
+ groups=resnet_groups,
589
+ dropout=dropout,
590
+ time_embedding_norm=resnet_time_scale_shift,
591
+ non_linearity=resnet_act_fn,
592
+ output_scale_factor=output_scale_factor,
593
+ pre_norm=resnet_pre_norm,
594
+ )
595
+ )
596
+
597
+ if dual_cross_attention:
598
+ raise NotImplementedError
599
+ attentions.append(
600
+ Transformer2DModel(
601
+ attn_num_head_channels,
602
+ out_channels // attn_num_head_channels,
603
+ in_channels=out_channels,
604
+ num_layers=1,
605
+ cross_attention_dim=cross_attention_dim,
606
+ norm_num_groups=resnet_groups,
607
+ use_linear_projection=use_linear_projection,
608
+ only_cross_attention=only_cross_attention,
609
+ upcast_attention=upcast_attention,
610
+ )
611
+ )
612
+ motion_modules.append(
613
+ get_motion_module(
614
+ in_channels=out_channels,
615
+ motion_module_type=motion_module_type,
616
+ motion_module_kwargs=motion_module_kwargs,
617
+ ) if use_motion_module else None
618
+ )
619
+
620
+ self.attentions = nn.ModuleList(attentions)
621
+ self.resnets = nn.ModuleList(resnets)
622
+ self.motion_modules = nn.ModuleList(motion_modules) if use_motion_module else motion_modules
623
+
624
+ if add_upsample:
625
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
626
+ else:
627
+ self.upsamplers = None
628
+
629
+ self.gradient_checkpointing = False
630
+
631
+ def forward(
632
+ self,
633
+ hidden_states,
634
+ res_hidden_states_tuple,
635
+ temb=None,
636
+ encoder_hidden_states=None,
637
+ upsample_size=None,
638
+ attention_mask=None,
639
+ motion_module_alpha=1.,
640
+ cross_attention_kwargs=None,
641
+ motion_cross_attention_kwargs={}
642
+ ):
643
+ video_length = hidden_states.shape[2]
644
+ temb_repeated = repeat(temb, "b c -> (b f) c", f=video_length)
645
+
646
+ lora_scale = getattr(self, "lora_scale", None)
647
+ if lora_scale != None:
648
+ cross_attention_kwargs = {"scale": lora_scale}
649
+ motion_lora_scale = getattr(self, "motion_lora_scale", None)
650
+ if motion_lora_scale != None:
651
+ if motion_cross_attention_kwargs is None:
652
+ motion_cross_attention_kwargs = {"scale": motion_lora_scale}
653
+ else:
654
+ motion_cross_attention_kwargs.update({"scale": motion_lora_scale})
655
+
656
+ for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
657
+ # pop res hidden states
658
+ res_hidden_states = res_hidden_states_tuple[-1]
659
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
660
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
661
+
662
+ if self.training and self.gradient_checkpointing:
663
+ raise NotImplementedError
664
+
665
+ def create_custom_forward(module, return_dict=None):
666
+ def custom_forward(*inputs):
667
+ if return_dict is not None:
668
+ return module(*inputs, return_dict=return_dict)
669
+ else:
670
+ return module(*inputs)
671
+
672
+ return custom_forward
673
+
674
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
675
+ hidden_states = torch.utils.checkpoint.checkpoint(
676
+ create_custom_forward(attn, return_dict=False),
677
+ hidden_states,
678
+ encoder_hidden_states,
679
+ )[0]
680
+ if motion_module is not None:
681
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module),
682
+ hidden_states.requires_grad_(), temb,
683
+ encoder_hidden_states)
684
+
685
+ else:
686
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
687
+ hidden_states = resnet(hidden_states, temb_repeated)
688
+ hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length)
689
+
690
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
691
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states,
692
+ cross_attention_kwargs=cross_attention_kwargs).sample
693
+ hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length)
694
+
695
+ # motion module
696
+ if motion_module is not None:
697
+ # hidden_states = motion_module_alpha * motion_module(hidden_states, temb=temb, encoder_hidden_states=encoder_hidden_states) + hidden_states
698
+ hidden_states = motion_module(hidden_states, temb=temb, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=motion_cross_attention_kwargs)
699
+
700
+ if self.upsamplers is not None:
701
+ for upsampler in self.upsamplers:
702
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
703
+ hidden_states = upsampler(hidden_states, upsample_size)
704
+ hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length)
705
+
706
+ return hidden_states
707
+
708
+
709
+ class UpBlock3D(nn.Module):
710
+ def __init__(
711
+ self,
712
+ in_channels: int,
713
+ prev_output_channel: int,
714
+ out_channels: int,
715
+ temb_channels: int,
716
+ dropout: float = 0.0,
717
+ num_layers: int = 1,
718
+ resnet_eps: float = 1e-6,
719
+ resnet_time_scale_shift: str = "default",
720
+ resnet_act_fn: str = "swish",
721
+ resnet_groups: int = 32,
722
+ resnet_pre_norm: bool = True,
723
+ output_scale_factor=1.0,
724
+ add_upsample=True,
725
+
726
+ use_motion_module=None,
727
+ motion_module_type=None,
728
+ motion_module_kwargs=None,
729
+ ):
730
+ super().__init__()
731
+ resnets = []
732
+ motion_modules = []
733
+
734
+ for i in range(num_layers):
735
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
736
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
737
+
738
+ resnets.append(
739
+ ResnetBlock2D(
740
+ in_channels=resnet_in_channels + res_skip_channels,
741
+ out_channels=out_channels,
742
+ temb_channels=temb_channels,
743
+ eps=resnet_eps,
744
+ groups=resnet_groups,
745
+ dropout=dropout,
746
+ time_embedding_norm=resnet_time_scale_shift,
747
+ non_linearity=resnet_act_fn,
748
+ output_scale_factor=output_scale_factor,
749
+ pre_norm=resnet_pre_norm,
750
+ )
751
+ )
752
+ motion_modules.append(
753
+ get_motion_module(
754
+ in_channels=out_channels,
755
+ motion_module_type=motion_module_type,
756
+ motion_module_kwargs=motion_module_kwargs,
757
+ ) if use_motion_module else None
758
+ )
759
+
760
+ self.resnets = nn.ModuleList(resnets)
761
+ self.motion_modules = nn.ModuleList(motion_modules) if use_motion_module else motion_modules
762
+
763
+ if add_upsample:
764
+ self.upsamplers = nn.ModuleList([Upsample2D(out_channels, use_conv=True, out_channels=out_channels)])
765
+ else:
766
+ self.upsamplers = None
767
+
768
+ self.gradient_checkpointing = False
769
+
770
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, encoder_hidden_states=None,
771
+ motion_module_alpha=1., motion_cross_attention_kwargs={}, **kwargs):
772
+ video_length = hidden_states.shape[2]
773
+ temb_repeated = repeat(temb, "b c -> (b f) c", f=video_length)
774
+
775
+ motion_lora_scale = getattr(self, "motion_lora_scale", None)
776
+ if motion_lora_scale != None:
777
+ if motion_cross_attention_kwargs is None:
778
+ motion_cross_attention_kwargs = {"scale": motion_lora_scale}
779
+ else:
780
+ motion_cross_attention_kwargs.update({"scale": motion_lora_scale})
781
+
782
+ for resnet, motion_module in zip(self.resnets, self.motion_modules):
783
+ # pop res hidden states
784
+ res_hidden_states = res_hidden_states_tuple[-1]
785
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
786
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
787
+
788
+ if self.training and self.gradient_checkpointing:
789
+ raise NotImplementedError
790
+
791
+ def create_custom_forward(module):
792
+ def custom_forward(*inputs):
793
+ return module(*inputs)
794
+
795
+ return custom_forward
796
+
797
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
798
+ if motion_module is not None:
799
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(motion_module),
800
+ hidden_states.requires_grad_(), temb,
801
+ encoder_hidden_states)
802
+ else:
803
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
804
+ hidden_states = resnet(hidden_states, temb_repeated)
805
+ hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length)
806
+
807
+ # motion module
808
+ if motion_module is not None:
809
+ hidden_states = motion_module(hidden_states, temb=temb, encoder_hidden_states=encoder_hidden_states, cross_attention_kwargs=motion_cross_attention_kwargs)
810
+
811
+ if self.upsamplers is not None:
812
+ for upsampler in self.upsamplers:
813
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
814
+ hidden_states = upsampler(hidden_states, upsample_size)
815
+ hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length)
816
+
817
+ return hidden_states
818
+
genphoto/pipelines/pipeline_animation.py CHANGED
@@ -1,3 +1,719 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:453fc7220c98fbe0fa70b19aade5b4403e470c09efed70147f2fcf35dd782d5b
3
- size 34090
1
+ # Adapted from https://github.com/showlab/Tune-A-Video/blob/main/tuneavideo/pipelines/pipeline_tuneavideo.py
2
+
3
+ import inspect
4
+ import torch
5
+
6
+ import numpy as np
7
+
8
+ from typing import Callable, List, Optional, Union
9
+ from dataclasses import dataclass
10
+ from diffusers.utils import is_accelerate_available
11
+ from packaging import version
12
+ from einops import rearrange
13
+ from transformers import CLIPTextModel, CLIPTokenizer
14
+ from diffusers.configuration_utils import FrozenDict
15
+ from diffusers.models import AutoencoderKL
16
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
17
+ from diffusers.schedulers import (
18
+ DDIMScheduler,
19
+ DPMSolverMultistepScheduler,
20
+ EulerAncestralDiscreteScheduler,
21
+ EulerDiscreteScheduler,
22
+ LMSDiscreteScheduler,
23
+ PNDMScheduler,
24
+ )
25
+ from diffusers.loaders import LoraLoaderMixin
26
+ from diffusers.utils import deprecate, logging, BaseOutput
27
+
28
+ from genphoto.models.camera_adaptor import CameraCameraEncoder
29
+ from genphoto.models.unet import UNet3DConditionModel
30
+
31
+
32
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
33
+
34
+
35
+ @dataclass
36
+ class AnimationPipelineOutput(BaseOutput):
37
+ videos: Union[torch.Tensor, np.ndarray]
38
+
39
+
40
+ class AnimationPipeline(DiffusionPipeline, LoraLoaderMixin):
41
+ _optional_components = []
42
+
43
+ def __init__(
44
+ self,
45
+ vae: AutoencoderKL,
46
+ text_encoder: CLIPTextModel,
47
+ tokenizer: CLIPTokenizer,
48
+ unet: UNet3DConditionModel,
49
+ scheduler: Union[
50
+ DDIMScheduler,
51
+ PNDMScheduler,
52
+ LMSDiscreteScheduler,
53
+ EulerDiscreteScheduler,
54
+ EulerAncestralDiscreteScheduler,
55
+ DPMSolverMultistepScheduler,
56
+ ],
57
+ ):
58
+ super().__init__()
59
+
60
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
61
+ deprecation_message = (
62
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
63
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
64
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
65
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
66
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
67
+ " file"
68
+ )
69
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
70
+ new_config = dict(scheduler.config)
71
+ new_config["steps_offset"] = 1
72
+ scheduler._internal_dict = FrozenDict(new_config)
73
+
74
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
75
+ deprecation_message = (
76
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
77
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
78
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
79
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
80
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
81
+ )
82
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
83
+ new_config = dict(scheduler.config)
84
+ new_config["clip_sample"] = False
85
+ scheduler._internal_dict = FrozenDict(new_config)
86
+
87
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
88
+ version.parse(unet.config._diffusers_version).base_version
89
+ ) < version.parse("0.9.0.dev0")
90
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
91
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
92
+ deprecation_message = (
93
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
94
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
95
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
96
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
97
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
98
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
99
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
100
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
101
+ " the `unet/config.json` file"
102
+ )
103
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
104
+ new_config = dict(unet.config)
105
+ new_config["sample_size"] = 64
106
+ unet._internal_dict = FrozenDict(new_config)
107
+
108
+ self.register_modules(
109
+ vae=vae,
110
+ text_encoder=text_encoder,
111
+ tokenizer=tokenizer,
112
+ unet=unet,
113
+ scheduler=scheduler,
114
+ )
115
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
116
+
117
+ def enable_vae_slicing(self):
118
+ self.vae.enable_slicing()
119
+
120
+ def disable_vae_slicing(self):
121
+ self.vae.disable_slicing()
122
+
123
+ def enable_sequential_cpu_offload(self, gpu_id=0):
124
+ if is_accelerate_available():
125
+ from accelerate import cpu_offload
126
+ else:
127
+ raise ImportError("Please install accelerate via `pip install accelerate`")
128
+
129
+ device = torch.device(f"cuda:{gpu_id}")
130
+
131
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
132
+ if cpu_offloaded_model is not None:
133
+ cpu_offload(cpu_offloaded_model, device)
134
+
135
+
136
+ @property
137
+ def _execution_device(self):
138
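+ # Returns the device on which the pipeline's modules are executed, taking accelerate CPU-offload hooks into account.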
+ if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
139
+ return self.device
140
+ for module in self.unet.modules():
141
+ if (
142
+ hasattr(module, "_hf_hook")
143
+ and hasattr(module._hf_hook, "execution_device")
144
+ and module._hf_hook.execution_device is not None
145
+ ):
146
+ return torch.device(module._hf_hook.execution_device)
147
+ return self.device
148
+
149
+ def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
150
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
151
+
152
+ text_inputs = self.tokenizer(
153
+ prompt,
154
+ padding="max_length",
155
+ max_length=self.tokenizer.model_max_length,
156
+ truncation=True,
157
+ return_tensors="pt",
158
+ )
159
+ text_input_ids = text_inputs.input_ids
160
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
161
+
162
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
163
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
164
+ logger.warning(
165
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
166
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
167
+ )
168
+
169
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
170
+ attention_mask = text_inputs.attention_mask.to(device)
171
+ else:
172
+ attention_mask = None
173
+
174
+ text_embeddings = self.text_encoder(
175
+ text_input_ids.to(device),
176
+ attention_mask=attention_mask,
177
+ )
178
+ text_embeddings = text_embeddings[0]
179
+
180
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
181
+ bs_embed, seq_len, _ = text_embeddings.shape
182
+ text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
183
+ text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
184
+
185
+ # get unconditional embeddings for classifier free guidance
186
+ if do_classifier_free_guidance:
187
+ uncond_tokens: List[str]
188
+ if negative_prompt is None:
189
+ uncond_tokens = [""] * batch_size
190
+ elif type(prompt) is not type(negative_prompt):
191
+ raise TypeError(
192
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
193
+ f" {type(prompt)}."
194
+ )
195
+ elif isinstance(negative_prompt, str):
196
+ uncond_tokens = [negative_prompt]
197
+ elif batch_size != len(negative_prompt):
198
+ raise ValueError(
199
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
200
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
201
+ " the batch size of `prompt`."
202
+ )
203
+ else:
204
+ uncond_tokens = negative_prompt
205
+
206
+ max_length = text_input_ids.shape[-1]
207
+ uncond_input = self.tokenizer(
208
+ uncond_tokens,
209
+ padding="max_length",
210
+ max_length=max_length,
211
+ truncation=True,
212
+ return_tensors="pt",
213
+ )
214
+
215
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
216
+ attention_mask = uncond_input.attention_mask.to(device)
217
+ else:
218
+ attention_mask = None
219
+
220
+ uncond_embeddings = self.text_encoder(
221
+ uncond_input.input_ids.to(device),
222
+ attention_mask=attention_mask,
223
+ )
224
+ uncond_embeddings = uncond_embeddings[0]
225
+
226
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
227
+ seq_len = uncond_embeddings.shape[1]
228
+ uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
229
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)
230
+
231
+ # For classifier free guidance, we need to do two forward passes.
232
+ # Here we concatenate the unconditional and text embeddings into a single batch
233
+ # to avoid doing two forward passes
234
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
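+ # the batch is ordered [unconditional, conditional]; the UNet output is later split in the same order via noise_pred.chunk(2)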
235
+
236
+ return text_embeddings
237
+
238
+ def decode_latents(self, latents):
239
+ video_length = latents.shape[2]
240
+ latents = 1 / 0.18215 * latents
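+ # 0.18215 is the Stable Diffusion VAE latent scaling factor; dividing by it undoes the scaling applied at encode time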
241
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
242
+ # video = self.vae.decode(latents).sample
243
+ video = []
244
+ for frame_idx in range(latents.shape[0]):
245
+ video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
246
+ video = torch.cat(video)
247
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
248
+ video = (video / 2 + 0.5).clamp(0, 1)
249
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
250
+ video = video.cpu().float().numpy()
251
+ return video
252
+
253
+ def prepare_extra_step_kwargs(self, generator, eta):
254
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
255
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
256
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
257
+ # and should be between [0, 1]
258
+
259
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
260
+ extra_step_kwargs = {}
261
+ if accepts_eta:
262
+ extra_step_kwargs["eta"] = eta
263
+
264
+ # check if the scheduler accepts generator
265
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
266
+ if accepts_generator:
267
+ extra_step_kwargs["generator"] = generator
268
+ return extra_step_kwargs
269
+
270
+ def check_inputs(self, prompt, height, width, callback_steps):
271
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
272
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
273
+
274
+ if height % 8 != 0 or width % 8 != 0:
275
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
276
+
277
+ if (callback_steps is None) or (
278
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
279
+ ):
280
+ raise ValueError(
281
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
282
+ f" {type(callback_steps)}."
283
+ )
284
+
285
+ def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
286
+ shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
287
+ if isinstance(generator, list) and len(generator) != batch_size:
288
+ raise ValueError(
289
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
290
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
291
+ )
292
+ if latents is None:
293
+ rand_device = "cpu" if device.type == "mps" else device
294
+
295
+ if isinstance(generator, list):
296
+ shape = shape
297
+ # shape = (1,) + shape[1:]
298
+ latents = [
299
+ torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
300
+ for i in range(batch_size)
301
+ ]
302
+ latents = torch.cat(latents, dim=0).to(device)
303
+ else:
304
+ latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
305
+ else:
306
+ if latents.shape != shape:
307
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
308
+ latents = latents.to(device)
309
+
310
+ # scale the initial noise by the standard deviation required by the scheduler
311
+ latents = latents * self.scheduler.init_noise_sigma
312
+ return latents
313
+
314
+ @torch.no_grad()
315
+ def __call__(
316
+ self,
317
+ prompt: Union[str, List[str]],
318
+ video_length: Optional[int],
319
+ height: Optional[int] = None,
320
+ width: Optional[int] = None,
321
+ num_inference_steps: int = 50,
322
+ guidance_scale: float = 7.5,
323
+ negative_prompt: Optional[Union[str, List[str]]] = None,
324
+ num_videos_per_prompt: Optional[int] = 1,
325
+ eta: float = 0.0,
326
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
327
+ latents: Optional[torch.FloatTensor] = None,
328
+ output_type: Optional[str] = "tensor",
329
+ return_dict: bool = True,
330
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
331
+ callback_steps: Optional[int] = 1,
332
+
333
+ multidiff_total_steps: int = 1,
334
+ multidiff_overlaps: int = 12,
335
+ **kwargs,
336
+ ):
337
+ # Default height and width to unet
338
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
339
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
340
+
341
+ # Check inputs. Raise error if not correct
342
+ self.check_inputs(prompt, height, width, callback_steps)
343
+
344
+ # Define call parameters
345
+ # batch_size = 1 if isinstance(prompt, str) else len(prompt)
346
+ batch_size = 1
347
+ if latents is not None:
348
+ batch_size = latents.shape[0]
349
+ if isinstance(prompt, list):
350
+ batch_size = len(prompt)
351
+
352
+ device = self._execution_device
353
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
354
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
355
+ # corresponds to doing no classifier free guidance.
356
+ do_classifier_free_guidance = guidance_scale > 1.0
357
+
358
+ # Encode input prompt
359
+ prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size
360
+ if negative_prompt is not None:
361
+ negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size
362
+ text_embeddings = self._encode_prompt(
363
+ prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
364
+ )
365
+
366
+ # Prepare timesteps
367
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
368
+ timesteps = self.scheduler.timesteps
369
+
370
+ # Prepare latent variables
371
+ single_model_length = video_length
372
+ video_length = multidiff_total_steps * (video_length - multidiff_overlaps) + multidiff_overlaps
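+ # e.g. multidiff_total_steps=3 with a 16-frame window and multidiff_overlaps=12 gives 3*(16-12)+12 = 24 latent frames in total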
373
+ num_channels_latents = self.unet.in_channels
374
+ latents = self.prepare_latents(
375
+ batch_size * num_videos_per_prompt,
376
+ num_channels_latents,
377
+ video_length,
378
+ height,
379
+ width,
380
+ text_embeddings.dtype,
381
+ device,
382
+ generator,
383
+ latents,
384
+ )
385
+ latents_dtype = latents.dtype
386
+
387
+ # Prepare extra step kwargs.
388
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
389
+
390
+ # Denoising loop
391
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
392
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
393
+ for i, t in enumerate(timesteps):
394
+ noise_pred_full = torch.zeros_like(latents).to(latents.device)
395
+ mask_full = torch.zeros_like(latents).to(latents.device)
396
+ noise_preds = []
397
+
398
+ for multidiff_step in range(multidiff_total_steps):
399
+ start_idx = multidiff_step * (single_model_length - multidiff_overlaps)
400
+ latent_partial = latents[:, :, start_idx: start_idx + single_model_length].contiguous()
401
+ mask_full[:, :, start_idx: start_idx + single_model_length] += 1
402
+
403
+ # expand the latents if we are doing classifier free guidance
404
+ latent_model_input = torch.cat([latent_partial] * 2) if do_classifier_free_guidance else latent_partial
405
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
406
+
407
+ # predict the noise residual
408
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample.to(dtype=latents_dtype)
409
+
410
+ # perform guidance
411
+ if do_classifier_free_guidance:
412
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
413
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
414
+ noise_preds.append(noise_pred)
415
+
416
+ for pred_idx, noise_pred in enumerate(noise_preds):
417
+ start_idx = pred_idx * (single_model_length - multidiff_overlaps)
418
+ noise_pred_full[:, :, start_idx: start_idx + single_model_length] += noise_pred / mask_full[:, :, start_idx: start_idx + single_model_length]
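+ # MultiDiffusion-style fusion: each latent position accumulates window predictions divided by its coverage count (mask_full), so overlapping windows are averaged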
419
+
420
+ # compute the previous noisy sample x_t -> x_t-1
421
+ latents = self.scheduler.step(noise_pred_full, t, latents, **extra_step_kwargs).prev_sample
422
+
423
+ # call the callback, if provided
424
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
425
+ progress_bar.update()
426
+ if callback is not None and i % callback_steps == 0:
427
+ callback(i, t, latents)
428
+
429
+ # Post-processing
430
+ video = self.decode_latents(latents)
431
+
432
+ # Convert to tensor
433
+ if output_type == "tensor":
434
+ video = torch.from_numpy(video)
435
+
436
+ if not return_dict:
437
+ return video
438
+
439
+ return AnimationPipelineOutput(videos=video)
440
+
441
+
442
+ class GenPhotoPipeline(AnimationPipeline):
443
+ _optional_components = []
444
+
445
+ def __init__(self,
446
+ vae: AutoencoderKL,
447
+ text_encoder: CLIPTextModel,
448
+ tokenizer: CLIPTokenizer,
449
+ unet: UNet3DConditionModel,
450
+ scheduler: Union[
451
+ DDIMScheduler,
452
+ PNDMScheduler,
453
+ LMSDiscreteScheduler,
454
+ EulerDiscreteScheduler,
455
+ EulerAncestralDiscreteScheduler,
456
+ DPMSolverMultistepScheduler],
457
+ camera_encoder: CameraCameraEncoder):
458
+
459
+ super().__init__(vae, text_encoder, tokenizer, unet, scheduler)
460
+
461
+ self.register_modules(
462
+ camera_encoder=camera_encoder
463
+ )
464
+
465
+ def decode_latents(self, latents):
466
+ video_length = latents.shape[2]
467
+ latents = 1 / 0.18215 * latents
468
+ latents = rearrange(latents, "b c f h w -> (b f) c h w")
469
+ # video = self.vae.decode(latents).sample
470
+ video = []
471
+ for frame_idx in range(latents.shape[0]):
472
+ video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
473
+ video = torch.cat(video)
474
+ video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
475
+ video = (video / 2 + 0.5).clamp(0, 1)
476
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
477
+ video = video.cpu().float().numpy()
478
+ return video
479
+
480
+ def _encode_prompt(self, prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt):
481
+ batch_size = len(prompt) if isinstance(prompt, list) else 1
482
+
483
+ text_inputs = self.tokenizer(
484
+ prompt,
485
+ padding="max_length",
486
+ max_length=self.tokenizer.model_max_length,
487
+ truncation=True,
488
+ return_tensors="pt",
489
+ )
490
+ text_input_ids = text_inputs.input_ids
491
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
492
+
493
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
494
+ removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
495
+ logger.warning(
496
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
497
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
498
+ )
499
+
500
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
501
+ attention_mask = text_inputs.attention_mask.to(device)
502
+ else:
503
+ attention_mask = None
504
+
505
+ text_embeddings = self.text_encoder(
506
+ text_input_ids.to(device),
507
+ attention_mask=attention_mask,
508
+ )
509
+ text_embeddings = text_embeddings[0]
510
+
511
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
512
+ bs_embed, seq_len, _ = text_embeddings.shape
513
+ text_embeddings = text_embeddings.repeat(1, num_videos_per_prompt, 1)
514
+ text_embeddings = text_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)
515
+
516
+ # get unconditional embeddings for classifier free guidance
517
+ if do_classifier_free_guidance:
518
+ uncond_tokens: List[str]
519
+ if negative_prompt is None:
520
+ uncond_tokens = [""] * batch_size
521
+ elif type(prompt) is not type(negative_prompt):
522
+ raise TypeError(
523
+ f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
524
+ f" {type(prompt)}."
525
+ )
526
+ elif isinstance(negative_prompt, str):
527
+ uncond_tokens = [negative_prompt]
528
+ elif batch_size != len(negative_prompt):
529
+ raise ValueError(
530
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
531
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
532
+ " the batch size of `prompt`."
533
+ )
534
+ else:
535
+ uncond_tokens = negative_prompt
536
+
537
+ max_length = text_input_ids.shape[-1]
538
+ uncond_input = self.tokenizer(
539
+ uncond_tokens,
540
+ padding="max_length",
541
+ max_length=max_length,
542
+ truncation=True,
543
+ return_tensors="pt",
544
+ )
545
+
546
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
547
+ attention_mask = uncond_input.attention_mask.to(device)
548
+ else:
549
+ attention_mask = None
550
+
551
+ uncond_embeddings = self.text_encoder(
552
+ uncond_input.input_ids.to(device),
553
+ attention_mask=attention_mask,
554
+ )
555
+ uncond_embeddings = uncond_embeddings[0]
556
+
557
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
558
+ seq_len = uncond_embeddings.shape[1]
559
+ uncond_embeddings = uncond_embeddings.repeat(1, num_videos_per_prompt, 1)
560
+ uncond_embeddings = uncond_embeddings.view(batch_size * num_videos_per_prompt, seq_len, -1)
561
+
562
+ # For classifier free guidance, we need to do two forward passes.
563
+ # Here we concatenate the unconditional and text embeddings into a single batch
564
+ # to avoid doing two forward passes
565
+ text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
566
+
567
+ return text_embeddings
568
+
569
+ @torch.no_grad()
570
+ def __call__(
571
+ self,
572
+ prompt: Union[str, List[str]],
573
+ camera_embedding: torch.FloatTensor,
574
+ video_length: Optional[int],
575
+ height: Optional[int] = None,
576
+ width: Optional[int] = None,
577
+ num_inference_steps: int = 50,
578
+ guidance_scale: float = 7.5,
579
+ negative_prompt: Optional[Union[str, List[str]]] = None,
580
+ num_videos_per_prompt: Optional[int] = 1,
581
+ eta: float = 0.0,
582
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
583
+ latents: Optional[torch.FloatTensor] = None,
584
+ output_type: Optional[str] = "tensor",
585
+ return_dict: bool = True,
586
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
587
+ callback_steps: Optional[int] = 1,
588
+ multidiff_total_steps: int = 1,
589
+ multidiff_overlaps: int = 12,
590
+ **kwargs,
591
+ ):
592
+ # Default height and width to unet
593
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
594
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
595
+
596
+ # Check inputs. Raise error if not correct
597
+ self.check_inputs(prompt, height, width, callback_steps)
598
+
599
+ # Define call parameters
600
+ # batch_size = 1 if isinstance(prompt, str) else len(prompt)
601
+ batch_size = 1
602
+ if latents is not None:
603
+ batch_size = latents.shape[0]
604
+ if isinstance(prompt, list):
605
+ batch_size = len(prompt)
606
+
607
+ device = camera_embedding[0].device if isinstance(camera_embedding, list) else camera_embedding.device
608
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
609
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
610
+ # corresponds to doing no classifier free guidance.
611
+ do_classifier_free_guidance = guidance_scale > 1.0
612
+
613
+ # Encode input prompt
614
+ prompt = prompt if isinstance(prompt, list) else [prompt] * batch_size
615
+ if negative_prompt is not None:
616
+ negative_prompt = negative_prompt if isinstance(negative_prompt, list) else [negative_prompt] * batch_size
617
+ text_embeddings = self._encode_prompt(
618
+ prompt, device, num_videos_per_prompt, do_classifier_free_guidance, negative_prompt
619
+ ) # [2bf, l, c]
620
+
621
+ # Prepare timesteps
622
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
623
+ timesteps = self.scheduler.timesteps
624
+
625
+ # Prepare latent variables
626
+ single_model_length = video_length
627
+ video_length = multidiff_total_steps * (video_length - multidiff_overlaps) + multidiff_overlaps
628
+ num_channels_latents = self.unet.in_channels
629
+ latents = self.prepare_latents(
630
+ batch_size * num_videos_per_prompt,
631
+ num_channels_latents,
632
+ video_length,
633
+ height,
634
+ width,
635
+ text_embeddings.dtype,
636
+ device,
637
+ generator,
638
+ latents,
639
+ ) # b c f h w
640
+ latents_dtype = latents.dtype
641
+
642
+ # Prepare extra step kwargs.
643
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
644
+ if isinstance(camera_embedding, list):
645
+ assert all([x.ndim == 5 for x in camera_embedding])
646
+ bs = camera_embedding[0].shape[0]
647
+ camera_embedding_features = []
648
+ for pe in camera_embedding:
649
+ camera_embedding_feature = self.camera_encoder(pe)
650
+ camera_embedding_feature = [rearrange(x, '(b f) c h w -> b c f h w', b=bs) for x in camera_embedding_feature]
651
+ camera_embedding_features.append(camera_embedding_feature)
652
+ else:
653
+ bs = camera_embedding.shape[0]
654
+ assert camera_embedding.ndim == 5
655
+ camera_embedding_features = self.camera_encoder(camera_embedding) # bf, c, h, w
656
+ camera_embedding_features = [rearrange(x, '(b f) c h w -> b c f h w', b=bs)
657
+ for x in camera_embedding_features]
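+ # the camera encoder returns feature maps of shape (b*f, c, h, w); they are reshaped to (b, c, f, h, w) to line up with the latent layout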
658
+
659
+ # Denoising loop
660
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
661
+ if isinstance(camera_embedding_features[0], list):
662
+ camera_embedding_features = [[torch.cat([x, x], dim=0) for x in camera_embedding_feature]
663
+ for camera_embedding_feature in camera_embedding_features] \
664
+ if do_classifier_free_guidance else camera_embedding_features
665
+ else:
666
+ camera_embedding_features = [torch.cat([x, x], dim=0) for x in camera_embedding_features] \
667
+ if do_classifier_free_guidance else camera_embedding_features # [2b c f h w]
668
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
669
+ for i, t in enumerate(timesteps):
670
+ noise_pred_full = torch.zeros_like(latents).to(latents.device)
671
+ mask_full = torch.zeros_like(latents).to(latents.device)
672
+ noise_preds = []
673
+ for multidiff_step in range(multidiff_total_steps):
674
+ start_idx = multidiff_step * (single_model_length - multidiff_overlaps)
675
+ latent_partial = latents[:, :, start_idx: start_idx + single_model_length].contiguous()
676
+ mask_full[:, :, start_idx: start_idx + single_model_length] += 1
677
+
678
+ if isinstance(camera_embedding, list):
679
+ camera_embedding_features_input = camera_embedding_features[multidiff_step]
680
+ else:
681
+ camera_embedding_features_input = [x[:, :, start_idx: start_idx + single_model_length]
682
+ for x in camera_embedding_features]
683
+
684
+ # expand the latents if we are doing classifier free guidance
685
+ latent_model_input = torch.cat([latent_partial] * 2) if do_classifier_free_guidance else latent_partial # [2b c f h w]
686
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
687
+
688
+ # predict the noise residual
689
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings,
690
+ camera_embedding_features=camera_embedding_features_input).sample.to(dtype=latents_dtype)
691
+ # perform guidance
692
+ if do_classifier_free_guidance:
693
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
694
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
695
+ noise_preds.append(noise_pred)
696
+ for pred_idx, noise_pred in enumerate(noise_preds):
697
+ start_idx = pred_idx * (single_model_length - multidiff_overlaps)
698
+ noise_pred_full[:, :, start_idx: start_idx + single_model_length] += noise_pred / mask_full[:, :, start_idx: start_idx + single_model_length]
699
+
700
+ # compute the previous noisy sample x_t -> x_t-1 b c f h w
701
+ latents = self.scheduler.step(noise_pred_full, t, latents, **extra_step_kwargs).prev_sample
702
+
703
+ # call the callback, if provided
704
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
705
+ progress_bar.update()
706
+ if callback is not None and i % callback_steps == 0:
707
+ callback(i, t, latents)
708
+
709
+ # Post-processing
710
+ video = self.decode_latents(latents)
711
+
712
+ # Convert to tensor
713
+ if output_type == "tensor":
714
+ video = torch.from_numpy(video)
715
+
716
+ if not return_dict:
717
+ return video
718
+
719
+ return AnimationPipelineOutput(videos=video)
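+
+ # Hypothetical usage sketch: given an assembled GenPhotoPipeline `pipeline` and a 5-D camera
+ # embedding tensor, a call could look roughly like:
+ #   sample = pipeline(prompt, camera_embedding=camera_embedding, video_length=5,
+ #                     num_inference_steps=25, guidance_scale=8.0).videos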
genphoto/utils/convert_from_ckpt.py CHANGED
@@ -1,3 +1,556 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:3ca60e78e034ed48ea1b7d48c09d2707940b1e25b749ee68bb6b601a96270435
3
- size 25125
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Conversion script for the Stable Diffusion checkpoints."""
16
+
17
+ import re
18
+ from transformers import CLIPTextModel
19
+
20
+ def shave_segments(path, n_shave_prefix_segments=1):
21
+ """
22
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
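+ e.g. shave_segments("input_blocks.1.0.in_layers.0", 1) -> "1.0.in_layers.0"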
23
+ """
24
+ if n_shave_prefix_segments >= 0:
25
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
26
+ else:
27
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
28
+
29
+
30
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
31
+ """
32
+ Updates paths inside resnets to the new naming scheme (local renaming)
33
+ """
34
+ mapping = []
35
+ for old_item in old_list:
36
+ new_item = old_item.replace("in_layers.0", "norm1")
37
+ new_item = new_item.replace("in_layers.2", "conv1")
38
+
39
+ new_item = new_item.replace("out_layers.0", "norm2")
40
+ new_item = new_item.replace("out_layers.3", "conv2")
41
+
42
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
43
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
44
+
45
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
46
+
47
+ mapping.append({"old": old_item, "new": new_item})
48
+
49
+ return mapping
50
+
51
+
52
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
53
+ """
54
+ Updates paths inside resnets to the new naming scheme (local renaming)
55
+ """
56
+ mapping = []
57
+ for old_item in old_list:
58
+ new_item = old_item
59
+
60
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
61
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
62
+
63
+ mapping.append({"old": old_item, "new": new_item})
64
+
65
+ return mapping
66
+
67
+
68
+ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
69
+ """
70
+ Updates paths inside attentions to the new naming scheme (local renaming)
71
+ """
72
+ mapping = []
73
+ for old_item in old_list:
74
+ new_item = old_item
75
+
76
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
77
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
78
+
79
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
80
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
81
+
82
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
83
+
84
+ mapping.append({"old": old_item, "new": new_item})
85
+
86
+ return mapping
87
+
88
+
89
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
90
+ """
91
+ Updates paths inside attentions to the new naming scheme (local renaming)
92
+ """
93
+ mapping = []
94
+ for old_item in old_list:
95
+ new_item = old_item
96
+
97
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
98
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
99
+
100
+ new_item = new_item.replace("q.weight", "query.weight")
101
+ new_item = new_item.replace("q.bias", "query.bias")
102
+
103
+ new_item = new_item.replace("k.weight", "key.weight")
104
+ new_item = new_item.replace("k.bias", "key.bias")
105
+
106
+ new_item = new_item.replace("v.weight", "value.weight")
107
+ new_item = new_item.replace("v.bias", "value.bias")
108
+
109
+ new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
110
+ new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
111
+
112
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
113
+
114
+ mapping.append({"old": old_item, "new": new_item})
115
+
116
+ return mapping
117
+
118
+
119
+ def assign_to_checkpoint(
120
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
121
+ ):
122
+ """
123
+ This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
124
+ attention layers, and takes into account additional replacements that may arise.
125
+
126
+ Assigns the weights to the new checkpoint.
127
+ """
128
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
129
+
130
+ # Splits the attention layers into three variables.
131
+ if attention_paths_to_split is not None:
132
+ for path, path_map in attention_paths_to_split.items():
133
+ old_tensor = old_checkpoint[path]
134
+ channels = old_tensor.shape[0] // 3
135
+
136
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
137
+
138
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
139
+
140
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
141
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
142
+
143
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
144
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
145
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
146
+
147
+ for path in paths:
148
+ new_path = path["new"]
149
+
150
+ # These have already been assigned
151
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
152
+ continue
153
+
154
+ # Global renaming happens here
155
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
156
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
157
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
158
+
159
+ if additional_replacements is not None:
160
+ for replacement in additional_replacements:
161
+ new_path = new_path.replace(replacement["old"], replacement["new"])
162
+
163
+ # proj_attn.weight has to be converted from conv 1D to linear
164
+ if "proj_attn.weight" in new_path:
165
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
166
+ else:
167
+ checkpoint[new_path] = old_checkpoint[path["old"]]
168
+
169
+
170
+ def conv_attn_to_linear(checkpoint):
171
+ keys = list(checkpoint.keys())
172
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
173
+ for key in keys:
174
+ if ".".join(key.split(".")[-2:]) in attn_keys:
175
+ if checkpoint[key].ndim > 2:
176
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
177
+ elif "proj_attn.weight" in key:
178
+ if checkpoint[key].ndim > 2:
179
+ checkpoint[key] = checkpoint[key][:, :, 0]
180
+
181
+
182
+ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False):
183
+ """
184
+ Takes a state dict and a config, and returns a converted checkpoint.
185
+ """
186
+
187
+ # extract state_dict for UNet
188
+ unet_state_dict = {}
189
+ keys = list(checkpoint.keys())
190
+
191
+ if controlnet:
192
+ unet_key = "control_model."
193
+ else:
194
+ unet_key = "model.diffusion_model."
195
+
196
+ # at least 100 parameters have to start with `model_ema` for the checkpoint to be treated as an EMA checkpoint
197
+ if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
198
+ print(f"Checkpoint {path} has both EMA and non-EMA weights.")
199
+ print(
200
+ "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
201
+ " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
202
+ )
203
+ for key in keys:
204
+ if key.startswith("model.diffusion_model"):
205
+ flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
206
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
207
+ else:
208
+ if sum(k.startswith("model_ema") for k in keys) > 100:
209
+ print(
210
+ "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
211
+ " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
212
+ )
213
+
214
+ for key in keys:
215
+ if key.startswith(unet_key):
216
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
217
+
218
+ new_checkpoint = {}
219
+
220
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
221
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
222
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
223
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
224
+
225
+ if config["class_embed_type"] is None:
226
+ # No parameters to port
227
+ ...
228
+ elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
229
+ new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
230
+ new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
231
+ new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
232
+ new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
233
+ else:
234
+ raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
235
+
236
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
237
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
238
+
239
+ if not controlnet:
240
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
241
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
242
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
243
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
244
+
245
+ # Retrieves the keys for the input blocks only
246
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
247
+ input_blocks = {
248
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
249
+ for layer_id in range(num_input_blocks)
250
+ }
251
+
252
+ # Retrieves the keys for the middle blocks only
253
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
254
+ middle_blocks = {
255
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
256
+ for layer_id in range(num_middle_blocks)
257
+ }
258
+
259
+ # Retrieves the keys for the output blocks only
260
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
261
+ output_blocks = {
262
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
263
+ for layer_id in range(num_output_blocks)
264
+ }
265
+
266
+ for i in range(1, num_input_blocks):
267
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
268
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
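+ # e.g. with layers_per_block=2, each group of (2+1) input blocks maps to one down_block: blocks 1-3 -> down_block 0, 4-6 -> down_block 1, and so on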
269
+
270
+ resnets = [
271
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
272
+ ]
273
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
274
+
275
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
276
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
277
+ f"input_blocks.{i}.0.op.weight"
278
+ )
279
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
280
+ f"input_blocks.{i}.0.op.bias"
281
+ )
282
+
283
+ paths = renew_resnet_paths(resnets)
284
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
285
+ assign_to_checkpoint(
286
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
287
+ )
288
+
289
+ if len(attentions):
290
+ paths = renew_attention_paths(attentions)
291
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
292
+ assign_to_checkpoint(
293
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
294
+ )
295
+
296
+ resnet_0 = middle_blocks[0]
297
+ attentions = middle_blocks[1]
298
+ resnet_1 = middle_blocks[2]
299
+
300
+ resnet_0_paths = renew_resnet_paths(resnet_0)
301
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
302
+
303
+ resnet_1_paths = renew_resnet_paths(resnet_1)
304
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
305
+
306
+ attentions_paths = renew_attention_paths(attentions)
307
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
308
+ assign_to_checkpoint(
309
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
310
+ )
311
+
312
+ for i in range(num_output_blocks):
313
+ block_id = i // (config["layers_per_block"] + 1)
314
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
315
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
316
+ output_block_list = {}
317
+
318
+ for layer in output_block_layers:
319
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
320
+ if layer_id in output_block_list:
321
+ output_block_list[layer_id].append(layer_name)
322
+ else:
323
+ output_block_list[layer_id] = [layer_name]
324
+
325
+ if len(output_block_list) > 1:
326
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
327
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
328
+
329
+ resnet_0_paths = renew_resnet_paths(resnets)
330
+ paths = renew_resnet_paths(resnets)
331
+
332
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
333
+ assign_to_checkpoint(
334
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
335
+ )
336
+
337
+ output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
338
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
339
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
340
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
341
+ f"output_blocks.{i}.{index}.conv.weight"
342
+ ]
343
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
344
+ f"output_blocks.{i}.{index}.conv.bias"
345
+ ]
346
+
347
+ # Clear attentions as they have been attributed above.
348
+ if len(attentions) == 2:
349
+ attentions = []
350
+
351
+ if len(attentions):
352
+ paths = renew_attention_paths(attentions)
353
+ meta_path = {
354
+ "old": f"output_blocks.{i}.1",
355
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
356
+ }
357
+ assign_to_checkpoint(
358
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
359
+ )
360
+ else:
361
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
362
+ for path in resnet_0_paths:
363
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
364
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
365
+
366
+ new_checkpoint[new_path] = unet_state_dict[old_path]
367
+
368
+ if controlnet:
369
+ # conditioning embedding
370
+
371
+ orig_index = 0
372
+
373
+ new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
374
+ f"input_hint_block.{orig_index}.weight"
375
+ )
376
+ new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
377
+ f"input_hint_block.{orig_index}.bias"
378
+ )
379
+
380
+ orig_index += 2
381
+
382
+ diffusers_index = 0
383
+
384
+ while diffusers_index < 6:
385
+ new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
386
+ f"input_hint_block.{orig_index}.weight"
387
+ )
388
+ new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
389
+ f"input_hint_block.{orig_index}.bias"
390
+ )
391
+ diffusers_index += 1
392
+ orig_index += 2
393
+
394
+ new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
395
+ f"input_hint_block.{orig_index}.weight"
396
+ )
397
+ new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
398
+ f"input_hint_block.{orig_index}.bias"
399
+ )
400
+
401
+ # down blocks
402
+ for i in range(num_input_blocks):
403
+ new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
404
+ new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")
405
+
406
+ # mid block
407
+ new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
408
+ new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")
409
+
410
+ return new_checkpoint
411
+
412
+
413
+ def convert_ldm_vae_checkpoint(checkpoint, config):
414
+ # extract state dict for VAE
415
+ vae_state_dict = {}
416
+ keys = list(checkpoint.keys())
417
+ vae_key = "first_stage_model." if any(k.startswith("first_stage_model.") for k in keys) else ""
418
+ for key in keys:
419
+ if key.startswith(vae_key):
420
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
421
+
422
+ new_checkpoint = {}
423
+
424
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
425
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
426
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
427
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
428
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
429
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
430
+
431
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
432
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
433
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
434
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
435
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
436
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
437
+
438
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
439
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
440
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
441
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
442
+
443
+ # Retrieves the keys for the encoder down blocks only
444
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
445
+ down_blocks = {
446
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
447
+ }
448
+
449
+ # Retrieves the keys for the decoder up blocks only
450
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
451
+ up_blocks = {
452
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
453
+ }
454
+
455
+ for i in range(num_down_blocks):
456
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
457
+
458
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
459
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
460
+ f"encoder.down.{i}.downsample.conv.weight"
461
+ )
462
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
463
+ f"encoder.down.{i}.downsample.conv.bias"
464
+ )
465
+
466
+ paths = renew_vae_resnet_paths(resnets)
467
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
468
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
469
+
470
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
471
+ num_mid_res_blocks = 2
472
+ for i in range(1, num_mid_res_blocks + 1):
473
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
474
+
475
+ paths = renew_vae_resnet_paths(resnets)
476
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
477
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
478
+
479
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
480
+ paths = renew_vae_attention_paths(mid_attentions)
481
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
482
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
483
+ conv_attn_to_linear(new_checkpoint)
484
+
485
+ for i in range(num_up_blocks):
486
+ block_id = num_up_blocks - 1 - i
487
+ resnets = [
488
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
489
+ ]
490
+
491
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
492
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
493
+ f"decoder.up.{block_id}.upsample.conv.weight"
494
+ ]
495
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
496
+ f"decoder.up.{block_id}.upsample.conv.bias"
497
+ ]
498
+
499
+ paths = renew_vae_resnet_paths(resnets)
500
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
501
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
502
+
503
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
504
+ num_mid_res_blocks = 2
505
+ for i in range(1, num_mid_res_blocks + 1):
506
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
507
+
508
+ paths = renew_vae_resnet_paths(resnets)
509
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
510
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
511
+
512
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
513
+ paths = renew_vae_attention_paths(mid_attentions)
514
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
515
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
516
+ conv_attn_to_linear(new_checkpoint)
517
+ return new_checkpoint
518
+
519
+
520
+ def convert_ldm_clip_checkpoint(checkpoint):
521
+ text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
522
+ keys = list(checkpoint.keys())
523
+
524
+ text_model_dict = {}
525
+
526
+ for key in keys:
527
+ if key.startswith("cond_stage_model.transformer"):
528
+ text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
529
+
530
+ text_model.load_state_dict(text_model_dict)
531
+
532
+ return text_model
533
+
534
+
535
+ textenc_conversion_lst = [
536
+ ("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"),
537
+ ("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
538
+ ("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"),
539
+ ("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"),
540
+ ]
541
+ textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}
542
+
543
+ textenc_transformer_conversion_lst = [
544
+ # (stable-diffusion, HF Diffusers)
545
+ ("resblocks.", "text_model.encoder.layers."),
546
+ ("ln_1", "layer_norm1"),
547
+ ("ln_2", "layer_norm2"),
548
+ (".c_fc.", ".fc1."),
549
+ (".c_proj.", ".fc2."),
550
+ (".attn", ".self_attn"),
551
+ ("ln_final.", "transformer.text_model.final_layer_norm."),
552
+ ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
553
+ ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
554
+ ]
555
+ protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
556
+ textenc_pattern = re.compile("|".join(protected.keys()))
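+ # the compiled pattern is applied further down in this file (outside this hunk), typically as
+ # textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], key) to rewrite text-encoder keys in one pass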
genphoto/utils/convert_lora_safetensor_to_diffusers.py CHANGED
@@ -1,3 +1,154 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:0c9162744237b045715cfe587c2be0117a49f538a99c1a853a2bf4c2d3695b69
3
- size 5981
1
+ # coding=utf-8
2
+ # Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """ Conversion script for LoRA safetensors checkpoints. """
17
+
18
+ import argparse
19
+
20
+ import torch
21
+ from safetensors.torch import load_file
22
+
23
+ from diffusers import StableDiffusionPipeline
24
+ import pdb
25
+
26
+
27
+
28
+ def convert_motion_lora_ckpt_to_diffusers(pipeline, state_dict, alpha=1.0):
29
+ # directly update weight in diffusers model
30
+ for key in state_dict:
31
+ # only process lora down key
32
+ if "up." in key: continue
33
+
34
+ up_key = key.replace(".down.", ".up.")
35
+ model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "")
36
+ model_key = model_key.replace("to_out.", "to_out.0.")
37
+ layer_infos = model_key.split(".")[:-1]
38
+
39
+ curr_layer = pipeline.unet
40
+ while len(layer_infos) > 0:
41
+ temp_name = layer_infos.pop(0)
42
+ curr_layer = curr_layer.__getattr__(temp_name)
43
+
44
+ weight_down = state_dict[key]
45
+ weight_up = state_dict[up_key]
46
+ curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
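+ # merge the low-rank update directly into the base weight: W <- W + alpha * (up @ down)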
47
+
48
+ return pipeline
49
+
50
+
51
+
52
+ def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6):
53
+ # load base model
54
+ # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
55
+
56
+ # load LoRA weight from .safetensors
57
+ # state_dict = load_file(checkpoint_path)
58
+
59
+ visited = []
60
+
61
+ # directly update weight in diffusers model
62
+ for key in state_dict:
63
+ # keys usually look like the example below, e.g.
64
+ # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
65
+
66
+ # alpha is already set, so skip `.alpha` entries and keys that were already visited
67
+ if ".alpha" in key or key in visited:
68
+ continue
69
+
70
+ if "text" in key:
71
+ layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
72
+ curr_layer = pipeline.text_encoder
73
+ else:
74
+ layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
75
+ curr_layer = pipeline.unet
76
+
77
+ # find the target layer
78
+ temp_name = layer_infos.pop(0)
79
+ while len(layer_infos) > -1:
80
+ try:
81
+ curr_layer = curr_layer.__getattr__(temp_name)
82
+ if len(layer_infos) > 0:
83
+ temp_name = layer_infos.pop(0)
84
+ elif len(layer_infos) == 0:
85
+ break
86
+ except Exception:
87
+ if len(temp_name) > 0:
88
+ temp_name += "_" + layer_infos.pop(0)
89
+ else:
90
+ temp_name = layer_infos.pop(0)
91
+
92
+ pair_keys = []
93
+ if "lora_down" in key:
94
+ pair_keys.append(key.replace("lora_down", "lora_up"))
95
+ pair_keys.append(key)
96
+ else:
97
+ pair_keys.append(key)
98
+ pair_keys.append(key.replace("lora_up", "lora_down"))
99
+
100
+ # update weight
101
+ if len(state_dict[pair_keys[0]].shape) == 4:
102
+ weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
103
+ weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
104
+ curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device)
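+ # this branch assumes 1x1 conv LoRA weights: they are flattened to 2-D for the matmul and reshaped back to (out, in, 1, 1) afterwards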
105
+ else:
106
+ weight_up = state_dict[pair_keys[0]].to(torch.float32)
107
+ weight_down = state_dict[pair_keys[1]].to(torch.float32)
108
+ curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
109
+
110
+ # update visited list
111
+ for item in pair_keys:
112
+ visited.append(item)
113
+
114
+ return pipeline
115
+
116
+
117
+ if __name__ == "__main__":
118
+ parser = argparse.ArgumentParser()
119
+
120
+ parser.add_argument(
121
+ "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format."
122
+ )
123
+ parser.add_argument(
124
+ "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
125
+ )
126
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
127
+ parser.add_argument(
128
+ "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors"
129
+ )
130
+ parser.add_argument(
131
+ "--lora_prefix_text_encoder",
132
+ default="lora_te",
133
+ type=str,
134
+ help="The prefix of text encoder weight in safetensors",
135
+ )
136
+ parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW")
137
+ parser.add_argument(
138
+ "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not."
139
+ )
140
+ parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
141
+
142
+ args = parser.parse_args()
143
+
144
+ base_model_path = args.base_model_path
145
+ checkpoint_path = args.checkpoint_path
146
+ dump_path = args.dump_path
147
+ lora_prefix_unet = args.lora_prefix_unet
148
+ lora_prefix_text_encoder = args.lora_prefix_text_encoder
149
+ alpha = args.alpha
150
+
151
+ pipe = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
+ state_dict = load_file(checkpoint_path)
+ pipe = convert_lora(pipe, state_dict, lora_prefix_unet, lora_prefix_text_encoder, alpha)
152
+
153
+ pipe = pipe.to(args.device)
154
+ pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
genphoto/utils/util.py CHANGED
@@ -1,3 +1,148 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:cb53dbb7da4c905c1a68d9f74d5ac1e01ea13e82a2506117bc3d3436109bb1b4
3
- size 4875
1
+ import os
2
+ import functools
3
+ import logging
4
+ import sys
5
+ import imageio
6
+ import atexit
7
+ import importlib
8
+ import torch
9
+ import torchvision
10
+ import numpy as np
11
+ from termcolor import colored
12
+
13
+ from einops import rearrange
14
+
15
+
16
+ def instantiate_from_config(config, **additional_kwargs):
17
+ if "target" not in config:
18
+ if config == '__is_first_stage__':
19
+ return None
20
+ elif config == "__is_unconditional__":
21
+ return None
22
+ raise KeyError("Expected key `target` to instantiate.")
23
+
24
+ additional_kwargs.update(config.get("kwargs", dict()))
25
+ return get_obj_from_str(config["target"])(**additional_kwargs)
26
+
27
+
28
+ def get_obj_from_str(string, reload=False):
29
+ module, cls = string.rsplit(".", 1)
30
+ if reload:
31
+ module_imp = importlib.import_module(module)
32
+ importlib.reload(module_imp)
33
+ return getattr(importlib.import_module(module, package=None), cls)
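+ # Minimal usage sketch: any importable class can be built from a config dict, e.g.
+ # instantiate_from_config({"target": "torch.nn.Linear", "kwargs": {"in_features": 8, "out_features": 4}});
+ # extra keyword arguments passed to instantiate_from_config are merged with config["kwargs"].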
34
+
35
+
36
+ def save_videos_grid(videos: torch.Tensor, path: str, rescale=False, n_rows=6, fps=8):
37
+ videos = rearrange(videos, "b c t h w -> t b c h w")
38
+ outputs = []
39
+ for x in videos:
40
+ x = torchvision.utils.make_grid(x, nrow=n_rows)
41
+ x = x.transpose(0, 1).transpose(1, 2).squeeze(-1)
42
+ if rescale:
43
+ x = (x + 1.0) / 2.0 # -1,1 -> 0,1
44
+ x = (x * 255).numpy().astype(np.uint8)
45
+ outputs.append(x)
46
+
47
+ os.makedirs(os.path.dirname(path), exist_ok=True)
48
+ imageio.mimsave(path, outputs, fps=fps)
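+ # Expects `videos` shaped [b, c, t, h, w] with values in [0, 1] (or in [-1, 1] with rescale=True);
+ # e.g. save_videos_grid(torch.rand(1, 3, 5, 256, 384), "outputs/sample.mp4", fps=8)  # illustrative path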
49
+
50
+
51
+ # Logger utils are copied from detectron2
52
+ class _ColorfulFormatter(logging.Formatter):
53
+ def __init__(self, *args, **kwargs):
54
+ self._root_name = kwargs.pop("root_name") + "."
55
+ self._abbrev_name = kwargs.pop("abbrev_name", "")
56
+ if len(self._abbrev_name):
57
+ self._abbrev_name = self._abbrev_name + "."
58
+ super(_ColorfulFormatter, self).__init__(*args, **kwargs)
59
+
60
+ def formatMessage(self, record):
61
+ record.name = record.name.replace(self._root_name, self._abbrev_name)
62
+ log = super(_ColorfulFormatter, self).formatMessage(record)
63
+ if record.levelno == logging.WARNING:
64
+ prefix = colored("WARNING", "red", attrs=["blink"])
65
+ elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
66
+ prefix = colored("ERROR", "red", attrs=["blink", "underline"])
67
+ else:
68
+ return log
69
+ return prefix + " " + log
70
+
71
+
72
+ # cache the opened file object, so that different calls to `setup_logger`
73
+ # with the same file name can safely write to the same file.
74
+ @functools.lru_cache(maxsize=None)
75
+ def _cached_log_stream(filename):
76
+ # use 1K buffer if writing to cloud storage
77
+ io = open(filename, "a", buffering=1024 if "://" in filename else -1)
78
+ atexit.register(io.close)
79
+ return io
80
+
81
+ @functools.lru_cache()
82
+ def setup_logger(output, distributed_rank, color=True, name='AnimateDiff', abbrev_name=None):
83
+ logger = logging.getLogger(name)
84
+ logger.setLevel(logging.DEBUG)
85
+ logger.propagate = False
86
+
87
+ if abbrev_name is None:
88
+ abbrev_name = 'AD'
89
+ plain_formatter = logging.Formatter(
90
+ "[%(asctime)s] %(name)s:%(lineno)d %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
91
+ )
92
+
93
+ # stdout logging: master only
94
+ if distributed_rank == 0:
95
+ ch = logging.StreamHandler(stream=sys.stdout)
96
+ ch.setLevel(logging.DEBUG)
97
+ if color:
98
+ formatter = _ColorfulFormatter(
99
+ colored("[%(asctime)s %(name)s:%(lineno)d]: ", "green") + "%(message)s",
100
+ datefmt="%m/%d %H:%M:%S",
101
+ root_name=name,
102
+ abbrev_name=str(abbrev_name),
103
+ )
104
+ else:
105
+ formatter = plain_formatter
106
+ ch.setFormatter(formatter)
107
+ logger.addHandler(ch)
108
+
109
+ # file logging: all workers
110
+ if output is not None:
111
+ if output.endswith(".txt") or output.endswith(".log"):
112
+ filename = output
113
+ else:
114
+ filename = os.path.join(output, "log.txt")
115
+ if distributed_rank > 0:
116
+ filename = filename + ".rank{}".format(distributed_rank)
117
+ os.makedirs(os.path.dirname(filename), exist_ok=True)
118
+
119
+ fh = logging.StreamHandler(_cached_log_stream(filename))
120
+ fh.setLevel(logging.DEBUG)
121
+ fh.setFormatter(plain_formatter)
122
+ logger.addHandler(fh)
123
+
124
+ return logger
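+ # Usage sketch (output path is illustrative): setup_logger("./logs", distributed_rank=0) logs colored
+ # messages to stdout on rank 0 and plain text to ./logs/log.txt (rank-suffixed files on other ranks).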
125
+
126
+
127
+ def format_time(elapsed_time):
128
+ # Time thresholds
129
+ minute = 60
130
+ hour = 60 * minute
131
+ day = 24 * hour
132
+
133
+ days, remainder = divmod(elapsed_time, day)
134
+ hours, remainder = divmod(remainder, hour)
135
+ minutes, seconds = divmod(remainder, minute)
136
+
137
+ formatted_time = ""
138
+
139
+ if days > 0:
140
+ formatted_time += f"{int(days)} days "
141
+ if hours > 0:
142
+ formatted_time += f"{int(hours)} hours "
143
+ if minutes > 0:
144
+ formatted_time += f"{int(minutes)} minutes "
145
+ if seconds > 0:
146
+ formatted_time += f"{seconds:.2f} seconds"
147
+
148
+ return formatted_time.strip()
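+ # For example, format_time(3725.5) returns "1 hours 2 minutes 5.50 seconds"; zero-valued units are omitted.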
inference_bokehK.py CHANGED
@@ -1,3 +1,216 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:d71226af9e998b6f458bf837712b9ebcffab037ac09c29ea48742ba4d832b257
3
- size 8968
1
+ import tempfile
2
+ import imageio
3
+ import os
4
+ import torch
5
+ import logging
6
+ import argparse
7
+ import json
8
+ import numpy as np
9
+ import torch.nn.functional as F
10
+ from pathlib import Path
11
+ from omegaconf import OmegaConf
12
+ from torch.utils.data import Dataset
13
+ from transformers import CLIPTextModel, CLIPTokenizer
14
+ from diffusers import AutoencoderKL, DDIMScheduler
15
+ from einops import rearrange
16
+
17
+ from genphoto.pipelines.pipeline_animation import GenPhotoPipeline
18
+ from genphoto.models.unet import UNet3DConditionModelCameraCond
19
+ from genphoto.models.camera_adaptor import CameraCameraEncoder, CameraAdaptor
20
+ from genphoto.utils.util import save_videos_grid
21
+
22
+ logging.basicConfig(level=logging.INFO)
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ from huggingface_hub import hf_hub_download
27
+
28
+
29
+
30
+ def create_bokehK_embedding(bokehK_values, target_height, target_width):
31
+ f = bokehK_values.shape[0]
32
+ bokehK_embedding = torch.zeros((f, 3, target_height, target_width), dtype=bokehK_values.dtype)
33
+
34
+ for i in range(f):
35
+ K_value = bokehK_values[i].item()
36
+ kernel_size = max(K_value, 1)
37
+ sigma = K_value / 3.0
38
+
39
+ ax = np.linspace(-(kernel_size / 2), kernel_size / 2, int(np.ceil(kernel_size)))
40
+ xx, yy = np.meshgrid(ax, ax)
41
+ kernel = np.exp(-(xx ** 2 + yy ** 2) / (2 * sigma ** 2))
42
+ kernel /= np.sum(kernel)
43
+ scale = kernel[int(np.ceil(kernel_size) / 2), int(np.ceil(kernel_size) / 2)]
44
+
45
+ bokehK_embedding[i] = scale
46
+
47
+ return bokehK_embedding
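+ # Shape sketch: for the five-frame setup used below, create_bokehK_embedding(torch.tensor([[2.44], [8.3], [10.1], [17.2], [24.0]]), 256, 384)
+ # returns a [5, 3, 256, 384] tensor in which every pixel of frame i holds the peak value of that frame's normalized Gaussian blur kernel.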
48
+
49
+ class Camera_Embedding(Dataset):
50
+ def __init__(self, bokehK_values, tokenizer, text_encoder, device, sample_size=[256, 384]):
51
+ self.bokehK_values = bokehK_values.to(device)
52
+ self.tokenizer = tokenizer
53
+ self.text_encoder = text_encoder
54
+ self.device = device
55
+ self.sample_size = sample_size
56
+
57
+ def load(self):
58
+ if len(self.bokehK_values) != 5:
59
+ raise ValueError("Expected 5 bokehK values")
60
+
61
+ prompts = []
62
+ for bb in self.bokehK_values:
63
+ prompt = f"<bokeh kernel size: {bb.item()}>"
64
+ prompts.append(prompt)
65
+
66
+ with torch.no_grad():
67
+ prompt_ids = self.tokenizer(
68
+ prompts, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
69
+ ).input_ids.to(self.device)
70
+
71
+ encoder_hidden_states = self.text_encoder(input_ids=prompt_ids).last_hidden_state
72
+
73
+ differences = []
74
+ for i in range(1, encoder_hidden_states.size(0)):
75
+ diff = encoder_hidden_states[i] - encoder_hidden_states[i - 1]
76
+ diff = diff.unsqueeze(0)
77
+ differences.append(diff)
78
+
79
+ final_diff = encoder_hidden_states[-1] - encoder_hidden_states[0]
80
+ final_diff = final_diff.unsqueeze(0)
81
+ differences.append(final_diff)
82
+
83
+ concatenated_differences = torch.cat(differences, dim=0)
84
+
85
+ pad_length = 128 - concatenated_differences.size(1)
86
+ if pad_length > 0:
87
+ concatenated_differences_padded = F.pad(concatenated_differences, (0, 0, 0, pad_length))
88
+
89
+ ccl_embedding = concatenated_differences_padded.reshape(
90
+ concatenated_differences_padded.size(0), self.sample_size[0], self.sample_size[1]
91
+ ).unsqueeze(1).expand(-1, 3, -1, -1).to(self.device)
92
+
93
+ bokehK_embedding = create_bokehK_embedding(self.bokehK_values, self.sample_size[0], self.sample_size[1]).to(self.device)
94
+ camera_embedding = torch.cat((bokehK_embedding, ccl_embedding), dim=1)
95
+ return camera_embedding
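+ # The returned camera_embedding is [5, 6, 256, 384]: 3 channels of bokeh-kernel scale plus 3 channels
+ # built from frame-to-frame differences of the CLIP prompt embeddings (padded from length 77 to 128).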
96
+
97
+ def load_models(cfg):
98
+ device = "cuda" if torch.cuda.is_available() else "cpu"
99
+
100
+ pretrained_model_path = hf_hub_download("pandaphd/generative_photography", "stable-diffusion-v1-5/")
101
+ lora_ckpt_path = hf_hub_download("pandaphd/generative_photography", "weights/RealEstate10K_LoRA.ckpt")
102
+ motion_module_ckpt_path = hf_hub_download("pandaphd/generative_photography", "weights/v3_sd15_mm.ckpt")
103
+ camera_adaptor_ckpt_path = hf_hub_download("pandaphd/generative_photography", "weights/checkpoint-bokehK.ckpt")
104
+
105
+ noise_scheduler = DDIMScheduler(**OmegaConf.to_container(cfg.noise_scheduler_kwargs))
106
+ vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
107
+ vae.requires_grad_(False)
108
+
109
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
110
+ text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").to(device)
111
+ text_encoder.requires_grad_(False)
112
+
113
+ unet = UNet3DConditionModelCameraCond.from_pretrained_2d(
114
+ pretrained_model_path,
115
+ subfolder=cfg.unet_subfolder,
116
+ unet_additional_kwargs=cfg.unet_additional_kwargs
117
+ ).to(device)
118
+ unet.requires_grad_(False)
119
+
120
+ camera_encoder = CameraCameraEncoder(**cfg.camera_encoder_kwargs).to(device)
121
+ camera_encoder.requires_grad_(False)
122
+ camera_adaptor = CameraAdaptor(unet, camera_encoder)
123
+ camera_adaptor.requires_grad_(False)
124
+ camera_adaptor.to(device)
125
+
126
+ unet.set_all_attn_processor(
127
+ add_spatial_lora=cfg.lora_ckpt is not None,
128
+ add_motion_lora=cfg.motion_lora_rank > 0,
129
+ lora_kwargs={"lora_rank": cfg.lora_rank, "lora_scale": cfg.lora_scale},
130
+ motion_lora_kwargs={"lora_rank": cfg.motion_lora_rank, "lora_scale": cfg.motion_lora_scale},
131
+ **cfg.attention_processor_kwargs
132
+ )
133
+
134
+ if cfg.lora_ckpt is not None:
135
+ lora_checkpoints = torch.load(lora_ckpt_path, map_location=unet.device)
136
+ if 'lora_state_dict' in lora_checkpoints.keys():
137
+ lora_checkpoints = lora_checkpoints['lora_state_dict']
138
+ _, lora_u = unet.load_state_dict(lora_checkpoints, strict=False)
139
+ assert len(lora_u) == 0
140
+
141
+ if cfg.motion_module_ckpt is not None:
142
+ mm_checkpoints = torch.load(motion_module_ckpt_path, map_location=unet.device)
143
+ _, mm_u = unet.load_state_dict(mm_checkpoints, strict=False)
144
+ assert len(mm_u) == 0
145
+
146
+ if cfg.camera_adaptor_ckpt is not None:
147
+ camera_adaptor_checkpoint = torch.load(camera_adaptor_ckpt_path, map_location=device)
148
+ camera_encoder_state_dict = camera_adaptor_checkpoint['camera_encoder_state_dict']
149
+ attention_processor_state_dict = camera_adaptor_checkpoint['attention_processor_state_dict']
150
+ camera_enc_m, camera_enc_u = camera_adaptor.camera_encoder.load_state_dict(camera_encoder_state_dict, strict=False)
151
+ assert len(camera_enc_m) == 0 and len(camera_enc_u) == 0
152
+ _, attention_processor_u = camera_adaptor.unet.load_state_dict(attention_processor_state_dict, strict=False)
153
+ assert len(attention_processor_u) == 0
154
+
155
+ pipeline = GenPhotoPipeline(
156
+ vae=vae,
157
+ text_encoder=text_encoder,
158
+ tokenizer=tokenizer,
159
+ unet=unet,
160
+ scheduler=noise_scheduler,
161
+ camera_encoder=camera_encoder
162
+ ).to(device)
163
+
164
+ pipeline.enable_vae_slicing()
165
+ return pipeline, device
166
+
167
+ def run_inference(pipeline, tokenizer, text_encoder, base_scene, bokehK_list, device, video_length=5, height=256, width=384):
168
+
169
+
170
+ bokehK_values = json.loads(bokehK_list)
171
+ bokehK_values = torch.tensor(bokehK_values).unsqueeze(1)
172
+
173
+ camera_embedding = Camera_Embedding(bokehK_values, tokenizer, text_encoder, device).load()
174
+ camera_embedding = rearrange(camera_embedding.unsqueeze(0), "b f c h w -> b c f h w")
175
+
176
+ with torch.no_grad():
177
+ sample = pipeline(
178
+ prompt=base_scene,
179
+ camera_embedding=camera_embedding,
180
+ video_length=video_length,
181
+ height=height,
182
+ width=width,
183
+ num_inference_steps=25,
184
+ guidance_scale=8.0
185
+ ).videos[0].cpu()
186
+
187
+ temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
188
+ save_videos_grid(sample[None], temporal_video_path, rescale=False)
189
+
190
+
191
+ return temporal_video_path
192
+
193
+
194
+ def main(config_path, base_scene, bokehK_list):
195
+ torch.manual_seed(42)
196
+ cfg = OmegaConf.load(config_path)
197
+ logger.info("Loading models...")
198
+ pipeline, device = load_models(cfg)
199
+ logger.info("Starting inference...")
200
+
201
+ video_path = run_inference(pipeline, pipeline.tokenizer, pipeline.text_encoder, base_scene, bokehK_list, device)
202
+ logger.info(f"Video saved to {video_path}")
203
+
204
+
205
+ if __name__ == "__main__":
206
+ parser = argparse.ArgumentParser()
207
+ parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")
208
+ parser.add_argument("--base_scene", type=str, required=True, help="Scene description")
209
+ parser.add_argument("--bokehK_list", type=str, required=True, help="Bokeh K values as a JSON list string")
210
+ args = parser.parse_args()
211
+ main(args.config, args.base_scene, args.bokehK_list)
212
+
213
+
214
+ ## example
215
+ ## python inference_bokehK.py --config configs/inference_genphoto/adv3_256_384_genphoto_relora_bokehK.yaml --base_scene "A young boy wearing an orange jacket is standing on a crosswalk, waiting to cross the street." --bokehK_list "[2.44, 8.3, 10.1, 17.2, 24.0]"
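+ ## programmatic sketch (scene text is illustrative):
+ ## cfg = OmegaConf.load("configs/inference_genphoto/adv3_256_384_genphoto_relora_bokehK.yaml")
+ ## pipeline, device = load_models(cfg)
+ ## video_path = run_inference(pipeline, pipeline.tokenizer, pipeline.text_encoder, "A quiet street at dusk.", "[2.44, 8.3, 10.1, 17.2, 24.0]", device)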
216
+
inference_color_temperature.py CHANGED
@@ -1,3 +1,338 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:6ed5fe8385e56e837fdb7c8ca21973136a42f4c3b09c6223c800dcc60955d61d
3
- size 14631
1
+ import tempfile
2
+ import imageio
3
+ import os
4
+ import torch
5
+ import logging
6
+ import argparse
7
+ import json
8
+ import numpy as np
9
+ import torch.nn.functional as F
10
+ from pathlib import Path
11
+ from omegaconf import OmegaConf
12
+ from torch.utils.data import Dataset
13
+ from transformers import CLIPTextModel, CLIPTokenizer
14
+ from diffusers import AutoencoderKL, DDIMScheduler
15
+ from einops import rearrange
16
+
17
+ from genphoto.pipelines.pipeline_animation import GenPhotoPipeline
18
+ from genphoto.models.unet import UNet3DConditionModelCameraCond
19
+ from genphoto.models.camera_adaptor import CameraCameraEncoder, CameraAdaptor
20
+ from genphoto.utils.util import save_videos_grid
21
+
22
+ logging.basicConfig(level=logging.INFO)
23
+ logger = logging.getLogger(__name__)
24
+
25
+ from huggingface_hub import hf_hub_download
26
+
27
+
28
+ def kelvin_to_rgb(kelvin):
29
+ if torch.is_tensor(kelvin):
30
+ kelvin = kelvin.cpu().item()
31
+
32
+ temp = kelvin / 100.0
33
+
34
+ if temp <= 66:
35
+ red = 255
36
+ green = 99.4708025861 * np.log(temp) - 161.1195681661 if temp > 0 else 0
37
+ if temp <= 19:
38
+ blue = 0
39
+ else:
40
+ blue = 138.5177312231 * np.log(temp - 10) - 305.0447927307
41
+
42
+ elif 66 < temp <= 88:
43
+ red = 0.5 * (255 + 329.698727446 * ((temp - 60) ** -0.19332047592))
44
+ green = 0.5 * (288.1221695283 * ((temp - 60) ** -0.1155148492) +
45
+ (99.4708025861 * np.log(temp) - 161.1195681661 if temp > 0 else 0))
46
+ blue = 0.5 * (138.5177312231 * np.log(temp - 10) - 305.0447927307 + 255)
47
+
48
+ else:
49
+ red = 329.698727446 * ((temp - 60) ** -0.19332047592)
50
+ green = 288.1221695283 * ((temp - 60) ** -0.1155148492)
51
+ blue = 255
52
+
53
+ return np.array([red, green, blue], dtype=np.float32) / 255.0
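+ # Rough sanity check: kelvin_to_rgb(3500) is approximately [1.00, 0.75, 0.55] (warm, blue-deficient),
+ # while kelvin_to_rgb(6500) is close to [1.00, 1.00, 0.98] (near-neutral white).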
54
+
55
+
56
+ def create_color_temperature_embedding(color_temperature_values, target_height, target_width, min_color_temperature=2000, max_color_temperature=10000):
57
+ f = color_temperature_values.shape[0]
58
+ rgb_factors = []
59
+
60
+ # Compute RGB factors based on kelvin_to_rgb function
61
+ for color_temperature in color_temperature_values.squeeze():
62
+ kelvin = min_color_temperature + (color_temperature * (max_color_temperature - min_color_temperature)) # Map normalized color_temperature to actual Kelvin
63
+ rgb = kelvin_to_rgb(kelvin)
64
+ rgb_factors.append(rgb)
65
+
66
+ # Convert to tensor and expand to target dimensions
67
+ rgb_factors = torch.tensor(rgb_factors).float() # [f, 3]
68
+ rgb_factors = rgb_factors.unsqueeze(2).unsqueeze(3) # [f, 3, 1, 1]
69
+ color_temperature_embedding = rgb_factors.expand(f, 3, target_height, target_width) # [f, 3, target_height, target_width]
70
+
71
+ return color_temperature_embedding
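+ # Output sketch: for 5 frames this returns a [5, 3, target_height, target_width] tensor in which every
+ # pixel of frame i carries that frame's RGB white-balance factor computed by kelvin_to_rgb.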
72
+
73
+
74
+
75
+ class Camera_Embedding(Dataset):
76
+ def __init__(self, color_temperature_values, tokenizer, text_encoder, device, sample_size=[256, 384]):
77
+ self.color_temperature_values = color_temperature_values.to(device)
78
+ self.tokenizer = tokenizer
79
+ self.text_encoder = text_encoder
80
+ self.device = device
81
+ self.sample_size = sample_size
82
+
83
+ def load(self):
84
+
85
+ if len(self.color_temperature_values) != 5:
86
+ raise ValueError("Expected 5 color_temperature values")
87
+
88
+ # Generate prompts for each color_temperature value and append color_temperature information to caption
89
+ prompts = []
90
+ for ct in self.color_temperature_values:
91
+ prompt = f"<color temperature: {ct.item()}>"
92
+ prompts.append(prompt)
93
+
94
+
95
+ # Tokenize prompts and encode to get embeddings
96
+ with torch.no_grad():
97
+ prompt_ids = self.tokenizer(
98
+ prompts, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
99
+ ).input_ids.to(self.device)
100
+
101
+ encoder_hidden_states = self.text_encoder(input_ids=prompt_ids).last_hidden_state # Shape: (f, sequence_length, hidden_size)
102
+
103
+
104
+ # Calculate differences between consecutive embeddings (ignoring sequence_length)
105
+ differences = []
106
+ for i in range(1, encoder_hidden_states.size(0)):
107
+ diff = encoder_hidden_states[i] - encoder_hidden_states[i - 1]
108
+ diff = diff.unsqueeze(0)
109
+ differences.append(diff)
110
+
111
+
112
+ # Add the difference between the last and the first embedding
113
+ final_diff = encoder_hidden_states[-1] - encoder_hidden_states[0]
114
+ final_diff = final_diff.unsqueeze(0)
115
+ differences.append(final_diff)
116
+
117
+ # Concatenate differences along the batch dimension (f-1)
118
+ concatenated_differences = torch.cat(differences, dim=0)
119
+ frame = concatenated_differences.size(0)
121
+
122
+ pad_length = 128 - concatenated_differences.size(1)
123
+ if pad_length > 0:
124
+ concatenated_differences_padded = F.pad(concatenated_differences, (0, 0, 0, pad_length))
125
+
126
+
127
+ ccl_embedding = concatenated_differences_padded.reshape(frame, self.sample_size[0], self.sample_size[1])
128
+ ccl_embedding = ccl_embedding.unsqueeze(1)
129
+ ccl_embedding = ccl_embedding.expand(-1, 3, -1, -1)
130
+ ccl_embedding = ccl_embedding.to(self.device)
131
+ color_temperature_embedding = create_color_temperature_embedding(self.color_temperature_values, self.sample_size[0], self.sample_size[1]).to(self.device)
132
+ camera_embedding = torch.cat((color_temperature_embedding, ccl_embedding), dim=1)
133
+ return camera_embedding
134
+
135
+ #
136
+ # def load_models(cfg):
137
+ #
138
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
139
+ #
140
+ # noise_scheduler = DDIMScheduler(**OmegaConf.to_container(cfg.noise_scheduler_kwargs))
141
+ # vae = AutoencoderKL.from_pretrained(cfg.pretrained_model_path, subfolder="vae").to(device)
142
+ # vae.requires_grad_(False)
143
+ # tokenizer = CLIPTokenizer.from_pretrained(cfg.pretrained_model_path, subfolder="tokenizer")
144
+ # text_encoder = CLIPTextModel.from_pretrained(cfg.pretrained_model_path, subfolder="text_encoder").to(device)
145
+ # text_encoder.requires_grad_(False)
146
+ # unet = UNet3DConditionModelCameraCond.from_pretrained_2d(
147
+ # cfg.pretrained_model_path,
148
+ # subfolder=cfg.unet_subfolder,
149
+ # unet_additional_kwargs=cfg.unet_additional_kwargs
150
+ # ).to(device)
151
+ # unet.requires_grad_(False)
152
+ #
153
+ # camera_encoder = CameraCameraEncoder(**cfg.camera_encoder_kwargs).to(device)
154
+ # camera_encoder.requires_grad_(False)
155
+ # camera_adaptor = CameraAdaptor(unet, camera_encoder)
156
+ # camera_adaptor.requires_grad_(False)
157
+ # camera_adaptor.to(device)
158
+ #
159
+ # logger.info("Setting the attention processors")
160
+ # unet.set_all_attn_processor(
161
+ # add_spatial_lora=cfg.lora_ckpt is not None,
162
+ # add_motion_lora=cfg.motion_lora_rank > 0,
163
+ # lora_kwargs={"lora_rank": cfg.lora_rank, "lora_scale": cfg.lora_scale},
164
+ # motion_lora_kwargs={"lora_rank": cfg.motion_lora_rank, "lora_scale": cfg.motion_lora_scale},
165
+ # **cfg.attention_processor_kwargs
166
+ # )
167
+ #
168
+ # if cfg.lora_ckpt is not None:
169
+ # print(f"Loading the lora checkpoint from {cfg.lora_ckpt}")
170
+ # lora_checkpoints = torch.load(cfg.lora_ckpt, map_location=unet.device)
171
+ # if 'lora_state_dict' in lora_checkpoints.keys():
172
+ # lora_checkpoints = lora_checkpoints['lora_state_dict']
173
+ # _, lora_u = unet.load_state_dict(lora_checkpoints, strict=False)
174
+ # assert len(lora_u) == 0
175
+ # print(f'Loading done')
176
+ #
177
+ # if cfg.motion_module_ckpt is not None:
178
+ # print(f"Loading the motion module checkpoint from {cfg.motion_module_ckpt}")
179
+ # mm_checkpoints = torch.load(cfg.motion_module_ckpt, map_location=unet.device)
180
+ # _, mm_u = unet.load_state_dict(mm_checkpoints, strict=False)
181
+ # assert len(mm_u) == 0
182
+ # print("Loading done")
183
+ #
184
+ #
185
+ # if cfg.camera_adaptor_ckpt is not None:
186
+ # logger.info(f"Loading camera adaptor from {cfg.camera_adaptor_ckpt}")
187
+ # camera_adaptor_checkpoint = torch.load(cfg.camera_adaptor_ckpt, map_location=device)
188
+ # camera_encoder_state_dict = camera_adaptor_checkpoint['camera_encoder_state_dict']
189
+ # attention_processor_state_dict = camera_adaptor_checkpoint['attention_processor_state_dict']
190
+ # camera_enc_m, camera_enc_u = camera_adaptor.camera_encoder.load_state_dict(camera_encoder_state_dict, strict=False)
191
+ #
192
+ # assert len(camera_enc_m) == 0 and len(camera_enc_u) == 0
193
+ # _, attention_processor_u = camera_adaptor.unet.load_state_dict(attention_processor_state_dict, strict=False)
194
+ # assert len(attention_processor_u) == 0
195
+ #
196
+ # logger.info("Camera Adaptor loading done")
197
+ # else:
198
+ # logger.info("No Camera Adaptor checkpoint used")
199
+ #
200
+ # pipeline = GenPhotoPipeline(
201
+ # vae=vae,
202
+ # text_encoder=text_encoder,
203
+ # tokenizer=tokenizer,
204
+ # unet=unet,
205
+ # scheduler=noise_scheduler,
206
+ # camera_encoder=camera_encoder
207
+ # ).to(device)
208
+ #
209
+ # pipeline.enable_vae_slicing()
210
+ #
211
+ # return pipeline, device
212
+
213
+
214
+
215
+ def load_models(cfg):
216
+ device = "cuda" if torch.cuda.is_available() else "cpu"
217
+
218
+ pretrained_model_path = hf_hub_download("pandaphd/generative_photography", "stable-diffusion-v1-5/")
219
+ lora_ckpt_path = hf_hub_download("pandaphd/generative_photography", "weights/RealEstate10K_LoRA.ckpt")
220
+ motion_module_ckpt_path = hf_hub_download("pandaphd/generative_photography", "weights/v3_sd15_mm.ckpt")
221
+ camera_adaptor_ckpt_path = hf_hub_download("pandaphd/generative_photography", "weights/checkpoint-color_temperature.ckpt")
222
+
223
+ noise_scheduler = DDIMScheduler(**OmegaConf.to_container(cfg.noise_scheduler_kwargs))
224
+ vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
225
+ vae.requires_grad_(False)
226
+
227
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
228
+ text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").to(device)
229
+ text_encoder.requires_grad_(False)
230
+
231
+ unet = UNet3DConditionModelCameraCond.from_pretrained_2d(
232
+ pretrained_model_path,
233
+ subfolder=cfg.unet_subfolder,
234
+ unet_additional_kwargs=cfg.unet_additional_kwargs
235
+ ).to(device)
236
+ unet.requires_grad_(False)
237
+
238
+ camera_encoder = CameraCameraEncoder(**cfg.camera_encoder_kwargs).to(device)
239
+ camera_encoder.requires_grad_(False)
240
+ camera_adaptor = CameraAdaptor(unet, camera_encoder)
241
+ camera_adaptor.requires_grad_(False)
242
+ camera_adaptor.to(device)
243
+
244
+ unet.set_all_attn_processor(
245
+ add_spatial_lora=cfg.lora_ckpt is not None,
246
+ add_motion_lora=cfg.motion_lora_rank > 0,
247
+ lora_kwargs={"lora_rank": cfg.lora_rank, "lora_scale": cfg.lora_scale},
248
+ motion_lora_kwargs={"lora_rank": cfg.motion_lora_rank, "lora_scale": cfg.motion_lora_scale},
249
+ **cfg.attention_processor_kwargs
250
+ )
251
+
252
+ if cfg.lora_ckpt is not None:
253
+ lora_checkpoints = torch.load(lora_ckpt_path, map_location=unet.device)
254
+ if 'lora_state_dict' in lora_checkpoints.keys():
255
+ lora_checkpoints = lora_checkpoints['lora_state_dict']
256
+ _, lora_u = unet.load_state_dict(lora_checkpoints, strict=False)
257
+ assert len(lora_u) == 0
258
+
259
+ if cfg.motion_module_ckpt is not None:
260
+ mm_checkpoints = torch.load(motion_module_ckpt_path, map_location=unet.device)
261
+ _, mm_u = unet.load_state_dict(mm_checkpoints, strict=False)
262
+ assert len(mm_u) == 0
263
+
264
+ if cfg.camera_adaptor_ckpt is not None:
265
+ camera_adaptor_checkpoint = torch.load(camera_adaptor_ckpt_path, map_location=device)
266
+ camera_encoder_state_dict = camera_adaptor_checkpoint['camera_encoder_state_dict']
267
+ attention_processor_state_dict = camera_adaptor_checkpoint['attention_processor_state_dict']
268
+ camera_enc_m, camera_enc_u = camera_adaptor.camera_encoder.load_state_dict(camera_encoder_state_dict, strict=False)
269
+ assert len(camera_enc_m) == 0 and len(camera_enc_u) == 0
270
+ _, attention_processor_u = camera_adaptor.unet.load_state_dict(attention_processor_state_dict, strict=False)
271
+ assert len(attention_processor_u) == 0
272
+
273
+ pipeline = GenPhotoPipeline(
274
+ vae=vae,
275
+ text_encoder=text_encoder,
276
+ tokenizer=tokenizer,
277
+ unet=unet,
278
+ scheduler=noise_scheduler,
279
+ camera_encoder=camera_encoder
280
+ ).to(device)
281
+
282
+ pipeline.enable_vae_slicing()
283
+ return pipeline, device
284
+
285
+
286
+
287
+
288
+ def run_inference(pipeline, tokenizer, text_encoder, base_scene, color_temperature_list, device, video_length=5, height=256, width=384):
289
+
290
+ color_temperature_values = json.loads(color_temperature_list)
291
+ color_temperature_values = torch.tensor(color_temperature_values).unsqueeze(1)
292
+
293
+ # Ensure camera_embedding is on the correct device
294
+ camera_embedding = Camera_Embedding(color_temperature_values, tokenizer, text_encoder, device).load()
295
+ camera_embedding = rearrange(camera_embedding.unsqueeze(0), "b f c h w -> b c f h w")
296
+
297
+ with torch.no_grad():
298
+ sample = pipeline(
299
+ prompt=base_scene,
300
+ camera_embedding=camera_embedding,
301
+ video_length=video_length,
302
+ height=height,
303
+ width=width,
304
+ num_inference_steps=25,
305
+ guidance_scale=8.0
306
+ ).videos[0].cpu()
307
+
308
+ temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
309
+ save_videos_grid(sample[None], temporal_video_path, rescale=False)
310
+
311
+
312
+ return temporal_video_path
313
+
314
+
315
+ def main(config_path, base_scene, color_temperature_list):
316
+ torch.manual_seed(42)
317
+ cfg = OmegaConf.load(config_path)
318
+ logger.info("Loading models...")
319
+ pipeline, device = load_models(cfg)
320
+ logger.info("Starting inference...")
321
+
322
+
323
+ video_path = run_inference(pipeline, pipeline.tokenizer, pipeline.text_encoder, base_scene, color_temperature_list, device)
324
+ logger.info(f"Video saved to {video_path}")
325
+
326
+
327
+ if __name__ == "__main__":
328
+ parser = argparse.ArgumentParser()
329
+ parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")
330
+ parser.add_argument("--base_scene", type=str, required=True, help="Invariant scene caption (plain text)")
331
+ parser.add_argument("--color_temperature_list", type=str, required=True, help="color_temperature values as JSON string")
332
+ args = parser.parse_args()
333
+ main(args.config, args.base_scene, args.color_temperature_list)
334
+
335
+ # usage example
336
+ # python inference_color_temperature.py --config configs/inference_genphoto/adv3_256_384_genphoto_relora_color_temperature.yaml --base_scene "A beautiful blue sky with a mountain range in the background." --color_temperature_list "[2455.0, 4155.0, 5555.0, 6555.0, 5855.0]"
337
+
338
+
inference_focal_length.py CHANGED
@@ -1,3 +1,335 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c41bc79a24be2dce1457e285e6fcd5cb3396b677bae30ae010e3f23ae993817c
3
- size 15177
1
+ import tempfile
2
+ import imageio
3
+ import os
4
+ import torch
5
+ import logging
6
+ import argparse
7
+ import json
8
+ import numpy as np
9
+ import torch.nn.functional as F
10
+ from pathlib import Path
11
+ from omegaconf import OmegaConf
12
+ from torch.utils.data import Dataset
13
+ from transformers import CLIPTextModel, CLIPTokenizer
14
+ from diffusers import AutoencoderKL, DDIMScheduler
15
+ from einops import rearrange
16
+
17
+ from genphoto.pipelines.pipeline_animation import GenPhotoPipeline
18
+ from genphoto.models.unet import UNet3DConditionModelCameraCond
19
+ from genphoto.models.camera_adaptor import CameraCameraEncoder, CameraAdaptor
20
+ from genphoto.utils.util import save_videos_grid
21
+
22
+ logging.basicConfig(level=logging.INFO)
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+
27
+ from huggingface_hub import hf_hub_download
28
+
29
+
30
+
31
+ def create_focal_length_embedding(focal_length_values, target_height, target_width, base_focal_length=24.0, sensor_height=24.0, sensor_width=36.0):
32
+ device = 'cpu'
33
+ focal_length_values = focal_length_values.to(device)
34
+ f = focal_length_values.shape[0] # Number of frames
35
+
36
+
37
+ # Convert constants to tensors to perform operations with focal_length_values
38
+ sensor_width = torch.tensor(sensor_width, device=device)
39
+ sensor_height = torch.tensor(sensor_height, device=device)
40
+ base_focal_length = torch.tensor(base_focal_length, device=device)
41
+
42
+ # Calculate the FOV for the base focal length (min_focal_length)
43
+ base_fov_x = 2.0 * torch.atan(sensor_width * 0.5 / base_focal_length)
44
+ base_fov_y = 2.0 * torch.atan(sensor_height * 0.5 / base_focal_length)
45
+
46
+ # Calculate the FOV for each focal length in focal_length_values
47
+ target_fov_x = 2.0 * torch.atan(sensor_width * 0.5 / focal_length_values)
48
+ target_fov_y = 2.0 * torch.atan(sensor_height * 0.5 / focal_length_values)
49
+
50
+ # Calculate crop ratio: how much of the image is cropped at the current focal length
51
+ crop_ratio_xs = target_fov_x / base_fov_x # Crop ratio for horizontal axis
52
+ crop_ratio_ys = target_fov_y / base_fov_y # Crop ratio for vertical axis
53
+
54
+ # Get the center of the image
55
+ center_h, center_w = target_height // 2, target_width // 2
56
+
57
+ # Initialize a mask tensor with zeros on CPU
58
+ focal_length_embedding = torch.zeros((f, 3, target_height, target_width), dtype=torch.float32) # Shape [f, 3, H, W]
59
+
60
+ # Fill the center region with 1 based on the calculated crop dimensions
61
+ for i in range(f):
62
+ # Crop dimensions calculated using rounded float values
63
+ crop_h = torch.round(crop_ratio_ys[i] * target_height).int().item() # Rounded cropped height for the current frame
64
+ # print('crop_h', crop_h)
65
+ crop_w = torch.round(crop_ratio_xs[i] * target_width).int().item() # Rounded cropped width for the current frame
66
+
67
+ # Ensure the cropped dimensions are within valid bounds
68
+ crop_h = max(1, min(target_height, crop_h))
69
+ crop_w = max(1, min(target_width, crop_w))
70
+
71
+ # Set the center region of the focal_length embedding to 1 for the current frame
72
+ focal_length_embedding[i, :,
73
+ center_h - crop_h // 2: center_h + crop_h // 2,
74
+ center_w - crop_w // 2: center_w + crop_w // 2] = 1.0
75
+
76
+ return focal_length_embedding
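+ # Worked example (approximate, for a 256x384 target): with base_focal_length=24.0 and a 36x24 mm sensor,
+ # a 48 mm frame gives FOV ratios of roughly 0.56 (x) and 0.53 (y), so about the central 214 of 384 columns
+ # and 135 of 256 rows are set to 1 while the border stays 0.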
77
+
78
+
79
+ class Camera_Embedding(Dataset):
80
+ def __init__(self, focal_length_values, tokenizer, text_encoder, device, sample_size=[256, 384]):
81
+ self.focal_length_values = focal_length_values.to(device)
82
+ self.tokenizer = tokenizer
83
+ self.text_encoder = text_encoder
84
+ self.device = device
85
+ self.sample_size = sample_size
86
+
87
+ def load(self):
88
+
89
+ if len(self.focal_length_values) != 5:
90
+ raise ValueError("Expected 5 focal_length values")
91
+
92
+ # Generate prompts for each focal length value and append focal_length information to caption
93
+ prompts = []
94
+ for fl in self.focal_length_values:
95
+ prompt = f"<focal length: {fl.item()}>"
96
+ prompts.append(prompt)
97
+
98
+
99
+ # Tokenize prompts and encode to get embeddings
100
+ with torch.no_grad():
101
+ prompt_ids = self.tokenizer(
102
+ prompts, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
103
+ ).input_ids.to(self.device)
104
+
105
+ encoder_hidden_states = self.text_encoder(input_ids=prompt_ids).last_hidden_state # Shape: (f, sequence_length, hidden_size)
106
+
107
+
108
+ # Calculate differences between consecutive embeddings (ignoring sequence_length)
109
+ differences = []
110
+ for i in range(1, encoder_hidden_states.size(0)):
111
+ diff = encoder_hidden_states[i] - encoder_hidden_states[i - 1]
112
+ diff = diff.unsqueeze(0)
113
+ differences.append(diff)
114
+
115
+ # Add the difference between the last and the first embedding
116
+ final_diff = encoder_hidden_states[-1] - encoder_hidden_states[0]
117
+ final_diff = final_diff.unsqueeze(0)
118
+ differences.append(final_diff)
119
+
120
+ # Concatenate differences along the batch dimension (f-1)
121
+ concatenated_differences = torch.cat(differences, dim=0)
122
+ frame = concatenated_differences.size(0)
124
+
125
+ pad_length = 128 - concatenated_differences.size(1)
126
+ if pad_length > 0:
127
+ # Pad along the second dimension (77 -> 128), pad only on the right side
128
+ concatenated_differences_padded = F.pad(concatenated_differences, (0, 0, 0, pad_length))
129
+
130
+
131
+ ccl_embedding = concatenated_differences_padded.reshape(frame, self.sample_size[0], self.sample_size[1])
132
+ ccl_embedding = ccl_embedding.unsqueeze(1)
133
+ ccl_embedding = ccl_embedding.expand(-1, 3, -1, -1)
134
+ ccl_embedding = ccl_embedding.to(self.device)
135
+ focal_length_embedding = create_focal_length_embedding(self.focal_length_values, self.sample_size[0], self.sample_size[1]).to(self.device)
136
+
137
+ camera_embedding = torch.cat((focal_length_embedding, ccl_embedding), dim=1)
138
+ return camera_embedding
139
+
140
+ #
141
+ # def load_models(cfg):
142
+ #
143
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
144
+ #
145
+ # noise_scheduler = DDIMScheduler(**OmegaConf.to_container(cfg.noise_scheduler_kwargs))
146
+ # vae = AutoencoderKL.from_pretrained(cfg.pretrained_model_path, subfolder="vae").to(device)
147
+ # vae.requires_grad_(False)
148
+ # tokenizer = CLIPTokenizer.from_pretrained(cfg.pretrained_model_path, subfolder="tokenizer")
149
+ # text_encoder = CLIPTextModel.from_pretrained(cfg.pretrained_model_path, subfolder="text_encoder").to(device)
150
+ # text_encoder.requires_grad_(False)
151
+ # unet = UNet3DConditionModelCameraCond.from_pretrained_2d(
152
+ # cfg.pretrained_model_path,
153
+ # subfolder=cfg.unet_subfolder,
154
+ # unet_additional_kwargs=cfg.unet_additional_kwargs
155
+ # ).to(device)
156
+ # unet.requires_grad_(False)
157
+ #
158
+ # camera_encoder = CameraCameraEncoder(**cfg.camera_encoder_kwargs).to(device)
159
+ # camera_encoder.requires_grad_(False)
160
+ # camera_adaptor = CameraAdaptor(unet, camera_encoder)
161
+ # camera_adaptor.requires_grad_(False)
162
+ # camera_adaptor.to(device)
163
+ #
164
+ # logger.info("Setting the attention processors")
165
+ # unet.set_all_attn_processor(
166
+ # add_spatial_lora=cfg.lora_ckpt is not None,
167
+ # add_motion_lora=cfg.motion_lora_rank > 0,
168
+ # lora_kwargs={"lora_rank": cfg.lora_rank, "lora_scale": cfg.lora_scale},
169
+ # motion_lora_kwargs={"lora_rank": cfg.motion_lora_rank, "lora_scale": cfg.motion_lora_scale},
170
+ # **cfg.attention_processor_kwargs
171
+ # )
172
+ #
173
+ # if cfg.lora_ckpt is not None:
174
+ # print(f"Loading the lora checkpoint from {cfg.lora_ckpt}")
175
+ # lora_checkpoints = torch.load(cfg.lora_ckpt, map_location=unet.device)
176
+ # if 'lora_state_dict' in lora_checkpoints.keys():
177
+ # lora_checkpoints = lora_checkpoints['lora_state_dict']
178
+ # _, lora_u = unet.load_state_dict(lora_checkpoints, strict=False)
179
+ # assert len(lora_u) == 0
180
+ # print(f'Loading done')
181
+ #
182
+ # if cfg.motion_module_ckpt is not None:
183
+ # print(f"Loading the motion module checkpoint from {cfg.motion_module_ckpt}")
184
+ # mm_checkpoints = torch.load(cfg.motion_module_ckpt, map_location=unet.device)
185
+ # _, mm_u = unet.load_state_dict(mm_checkpoints, strict=False)
186
+ # assert len(mm_u) == 0
187
+ # print("Loading done")
188
+ #
189
+ # if cfg.camera_adaptor_ckpt is not None:
190
+ # logger.info(f"Loading camera adaptor from {cfg.camera_adaptor_ckpt}")
191
+ # camera_adaptor_checkpoint = torch.load(cfg.camera_adaptor_ckpt, map_location=device)
192
+ # camera_encoder_state_dict = camera_adaptor_checkpoint['camera_encoder_state_dict']
193
+ # attention_processor_state_dict = camera_adaptor_checkpoint['attention_processor_state_dict']
194
+ # camera_enc_m, camera_enc_u = camera_adaptor.camera_encoder.load_state_dict(camera_encoder_state_dict, strict=False)
195
+ #
196
+ # assert len(camera_enc_m) == 0 and len(camera_enc_u) == 0
197
+ # _, attention_processor_u = camera_adaptor.unet.load_state_dict(attention_processor_state_dict, strict=False)
198
+ # assert len(attention_processor_u) == 0
199
+ #
200
+ # logger.info("Camera Adaptor loading done")
201
+ # else:
202
+ # logger.info("No Camera Adaptor checkpoint used")
203
+ #
204
+ # pipeline = GenPhotoPipeline(
205
+ # vae=vae,
206
+ # text_encoder=text_encoder,
207
+ # tokenizer=tokenizer,
208
+ # unet=unet,
209
+ # scheduler=noise_scheduler,
210
+ # camera_encoder=camera_encoder
211
+ # ).to(device)
212
+ # pipeline.enable_vae_slicing()
213
+ #
214
+ # return pipeline, device
215
+
216
+
217
+ def load_models(cfg):
218
+ device = "cuda" if torch.cuda.is_available() else "cpu"
219
+
220
+ pretrained_model_path = hf_hub_download("pandaphd/generative_photography", "stable-diffusion-v1-5/")
221
+ lora_ckpt_path = hf_hub_download("pandaphd/generative_photography", "weights/RealEstate10K_LoRA.ckpt")
222
+ motion_module_ckpt_path = hf_hub_download("pandaphd/generative_photography", "weights/v3_sd15_mm.ckpt")
223
+ camera_adaptor_ckpt_path = hf_hub_download("pandaphd/generative_photography", "weights/checkpoint-focal_length.ckpt")
224
+
225
+ noise_scheduler = DDIMScheduler(**OmegaConf.to_container(cfg.noise_scheduler_kwargs))
226
+ vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
227
+ vae.requires_grad_(False)
228
+
229
+ tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
230
+ text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").to(device)
231
+ text_encoder.requires_grad_(False)
232
+
233
+ unet = UNet3DConditionModelCameraCond.from_pretrained_2d(
234
+ pretrained_model_path,
235
+ subfolder=cfg.unet_subfolder,
236
+ unet_additional_kwargs=cfg.unet_additional_kwargs
237
+ ).to(device)
238
+ unet.requires_grad_(False)
239
+
240
+ camera_encoder = CameraCameraEncoder(**cfg.camera_encoder_kwargs).to(device)
241
+ camera_encoder.requires_grad_(False)
242
+ camera_adaptor = CameraAdaptor(unet, camera_encoder)
243
+ camera_adaptor.requires_grad_(False)
244
+ camera_adaptor.to(device)
245
+
246
+ unet.set_all_attn_processor(
247
+ add_spatial_lora=cfg.lora_ckpt is not None,
248
+ add_motion_lora=cfg.motion_lora_rank > 0,
249
+ lora_kwargs={"lora_rank": cfg.lora_rank, "lora_scale": cfg.lora_scale},
250
+ motion_lora_kwargs={"lora_rank": cfg.motion_lora_rank, "lora_scale": cfg.motion_lora_scale},
251
+ **cfg.attention_processor_kwargs
252
+ )
253
+
254
+ if cfg.lora_ckpt is not None:
255
+ lora_checkpoints = torch.load(lora_ckpt_path, map_location=unet.device)
256
+ if 'lora_state_dict' in lora_checkpoints.keys():
257
+ lora_checkpoints = lora_checkpoints['lora_state_dict']
258
+ _, lora_u = unet.load_state_dict(lora_checkpoints, strict=False)
259
+ assert len(lora_u) == 0
260
+
261
+ if cfg.motion_module_ckpt is not None:
262
+ mm_checkpoints = torch.load(motion_module_ckpt_path, map_location=unet.device)
263
+ _, mm_u = unet.load_state_dict(mm_checkpoints, strict=False)
264
+ assert len(mm_u) == 0
265
+
266
+ if cfg.camera_adaptor_ckpt is not None:
267
+ camera_adaptor_checkpoint = torch.load(camera_adaptor_ckpt_path, map_location=device)
268
+ camera_encoder_state_dict = camera_adaptor_checkpoint['camera_encoder_state_dict']
269
+ attention_processor_state_dict = camera_adaptor_checkpoint['attention_processor_state_dict']
270
+ camera_enc_m, camera_enc_u = camera_adaptor.camera_encoder.load_state_dict(camera_encoder_state_dict, strict=False)
271
+ assert len(camera_enc_m) == 0 and len(camera_enc_u) == 0
272
+ _, attention_processor_u = camera_adaptor.unet.load_state_dict(attention_processor_state_dict, strict=False)
273
+ assert len(attention_processor_u) == 0
274
+
275
+ pipeline = GenPhotoPipeline(
276
+ vae=vae,
277
+ text_encoder=text_encoder,
278
+ tokenizer=tokenizer,
279
+ unet=unet,
280
+ scheduler=noise_scheduler,
281
+ camera_encoder=camera_encoder
282
+ ).to(device)
283
+
284
+ pipeline.enable_vae_slicing()
285
+ return pipeline, device
286
+
287
+ def run_inference(pipeline, tokenizer, text_encoder, base_scene, focal_length_list, device, video_length=5, height=256, width=384):
288
+
289
+ focal_length_values = json.loads(focal_length_list)
290
+ focal_length_values = torch.tensor(focal_length_values).unsqueeze(1)
291
+
292
+ # Ensure camera_embedding is on the correct device
293
+ camera_embedding = Camera_Embedding(focal_length_values, tokenizer, text_encoder, device).load()
294
+ camera_embedding = rearrange(camera_embedding.unsqueeze(0), "b f c h w -> b c f h w")
295
+
296
+ with torch.no_grad():
297
+ sample = pipeline(
298
+ prompt=base_scene,
299
+ camera_embedding=camera_embedding,
300
+ video_length=video_length,
301
+ height=height,
302
+ width=width,
303
+ num_inference_steps=25,
304
+ guidance_scale=8.0
305
+ ).videos[0].cpu()
306
+
307
+ temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
308
+ save_videos_grid(sample[None], temporal_video_path, rescale=False)
309
+
310
+
311
+ return temporal_video_path
312
+
313
+
314
+ def main(config_path, base_scene, focal_length_list):
315
+ torch.manual_seed(42)
316
+ cfg = OmegaConf.load(config_path)
317
+ logger.info("Loading models...")
318
+ pipeline, device = load_models(cfg)
319
+ logger.info("Starting inference...")
320
+
321
+ video_path = run_inference(pipeline, pipeline.tokenizer, pipeline.text_encoder, base_scene, focal_length_list, device)
322
+ logger.info(f"Video saved to {video_path}")
323
+
324
+
325
+ if __name__ == "__main__":
326
+ parser = argparse.ArgumentParser()
327
+ parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")
328
+ parser.add_argument("--base_scene", type=str, required=True, help="Invariant scene caption (plain text)")
329
+ parser.add_argument("--focal_length_list", type=str, required=True, help="focal_length values as JSON string")
330
+ args = parser.parse_args()
331
+ main(args.config, args.base_scene, args.focal_length_list)
332
+
333
+ # usage example
334
+ # python inference_focal_length.py --config configs/inference_genphoto/adv3_256_384_genphoto_relora_focal_length.yaml --base_scene "A cozy living room with a large, comfy sofa and a coffee table." --focal_length_list "[25.0, 35.0, 45.0, 55.0, 65.0]"
335
+
inference_shutter_speed.py CHANGED
@@ -1,3 +1,322 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:12eb2507454a07a5e565233b738991782d191e932470176783be93773fb0f209
3
- size 13888
1
+ import tempfile
2
+ import imageio
3
+ import os
4
+ import torch
5
+ import logging
6
+ import argparse
7
+ import json
8
+ import numpy as np
9
+ import torch.nn.functional as F
10
+ from pathlib import Path
11
+ from omegaconf import OmegaConf
12
+ from torch.utils.data import Dataset
13
+ from transformers import CLIPTextModel, CLIPTokenizer
14
+ from diffusers import AutoencoderKL, DDIMScheduler
15
+ from einops import rearrange
16
+
17
+ from genphoto.pipelines.pipeline_animation import GenPhotoPipeline
18
+ from genphoto.models.unet import UNet3DConditionModelCameraCond
19
+ from genphoto.models.camera_adaptor import CameraCameraEncoder, CameraAdaptor
20
+ from genphoto.utils.util import save_videos_grid
21
+
22
+ logging.basicConfig(level=logging.INFO)
23
+ logger = logging.getLogger(__name__)
24
+
25
+
26
+ from huggingface_hub import hf_hub_download
27
+
28
+
29
+
30
+ def create_shutter_speed_embedding(shutter_speed_values, target_height, target_width, base_exposure=0.5):
31
+ """
32
+ Create a shutter_speed (Exposure Value or shutter speed) embedding tensor using a constant fwc value.
33
+ Args:
34
+ - shutter_speed_values: Tensor of shape [f, 1] containing shutter_speed values for each frame.
35
+ - target_height: Height of the embedding map.
36
+ - target_width: Width of the embedding map.
37
+ - base_exposure: A base exposure value used to normalize brightness (defaults to 0.5).
38
+
39
+ Returns:
40
+ - shutter_speed_embedding: Tensor of shape [f, 3, target_height, target_width] where each pixel is scaled based on the shutter_speed values.
41
+ """
42
+ f = shutter_speed_values.shape[0]
43
+
44
+ # Set a constant full well capacity (fwc)
45
+ fwc = 32000 # Constant value for full well capacity
46
+
47
+ # Calculate scale based on EV and sensor full well capacity (fwc)
48
+ scales = (shutter_speed_values / base_exposure) * (fwc / (fwc + 0.0001))
49
+
50
+ # Reshape and expand to match image dimensions
51
+ scales = scales.unsqueeze(2).unsqueeze(3).expand(f, 3, target_height, target_width)
52
+
53
+ # Use scales to create the final shutter_speed embedding
54
+ shutter_speed_embedding = scales # Shape [f, 3, H, W]
55
+
56
+ return shutter_speed_embedding
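+ # Quick numeric sketch: with base_exposure=0.5 the fwc factor is ~1, so a shutter_speed value of 1.0 fills
+ # its frame with ~2.0 and a value of 0.25 with ~0.5; the output shape is [f, 3, target_height, target_width].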
57
+
58
+
59
+
60
+ class Camera_Embedding(Dataset):
61
+ def __init__(self, shutter_speed_values, tokenizer, text_encoder, device, sample_size=[256, 384]):
62
+ self.shutter_speed_values = shutter_speed_values.to(device)
63
+ self.tokenizer = tokenizer
64
+ self.text_encoder = text_encoder
65
+ self.device = device
66
+ self.sample_size = sample_size
67
+
68
+ def load(self):
69
+
70
+ if len(self.shutter_speed_values) != 5:
71
+ raise ValueError("Expected 5 shutter_speed values")
72
+
73
+ # Generate prompts for each shutter_speed value and append shutter_speed information to caption
74
+ prompts = []
75
+ for ss in self.shutter_speed_values:
76
+ prompt = f"<exposure: {ss.item()}>"
77
+ prompts.append(prompt)
78
+
79
+ # Tokenize prompts and encode to get embeddings
80
+ with torch.no_grad():
81
+ prompt_ids = self.tokenizer(
82
+ prompts, max_length=self.tokenizer.model_max_length, padding="max_length", truncation=True, return_tensors="pt"
83
+ ).input_ids.to(self.device)
84
+
85
+ encoder_hidden_states = self.text_encoder(input_ids=prompt_ids).last_hidden_state # Shape: (f, sequence_length, hidden_size)
86
+
87
+
88
+ # Calculate differences between consecutive embeddings (ignoring sequence_length)
89
+ differences = []
90
+ for i in range(1, encoder_hidden_states.size(0)):
91
+ diff = encoder_hidden_states[i] - encoder_hidden_states[i - 1]
92
+ diff = diff.unsqueeze(0)
93
+ differences.append(diff)
94
+
95
+ # Add the difference between the last and the first embedding
96
+ final_diff = encoder_hidden_states[-1] - encoder_hidden_states[0]
97
+ final_diff = final_diff.unsqueeze(0)
98
+ differences.append(final_diff)
99
+
100
+ # Concatenate differences along the batch dimension (f-1)
101
+ concatenated_differences = torch.cat(differences, dim=0)
102
+ frame = concatenated_differences.size(0)
103
+
105
+
106
+ pad_length = 128 - concatenated_differences.size(1)
107
+ logger.debug(f"pad_length: {pad_length}")
108
+ if pad_length > 0:
109
+
110
+ concatenated_differences_padded = F.pad(concatenated_differences, (0, 0, 0, pad_length))
111
+
112
+
113
+ ccl_embedding = concatenated_differences_padded.reshape(frame, self.sample_size[0], self.sample_size[1])
114
+ ccl_embedding = ccl_embedding.unsqueeze(1)
115
+ ccl_embedding = ccl_embedding.expand(-1, 3, -1, -1)
116
+ ccl_embedding = ccl_embedding.to(self.device)
117
+ shutter_speed_embedding = create_shutter_speed_embedding(self.shutter_speed_values, self.sample_size[0], self.sample_size[1]).to(self.device)
118
+ camera_embedding = torch.cat((shutter_speed_embedding, ccl_embedding), dim=1)
119
+ return camera_embedding
120
+
121
+
122
+ # def load_models(cfg):
123
+ #
124
+ # device = "cuda" if torch.cuda.is_available() else "cpu"
125
+ #
126
+ # noise_scheduler = DDIMScheduler(**OmegaConf.to_container(cfg.noise_scheduler_kwargs))
127
+ # vae = AutoencoderKL.from_pretrained(cfg.pretrained_model_path, subfolder="vae").to(device)
128
+ # vae.requires_grad_(False)
129
+ # tokenizer = CLIPTokenizer.from_pretrained(cfg.pretrained_model_path, subfolder="tokenizer")
130
+ # text_encoder = CLIPTextModel.from_pretrained(cfg.pretrained_model_path, subfolder="text_encoder").to(device)
131
+ # text_encoder.requires_grad_(False)
132
+ #
133
+ # unet = UNet3DConditionModelCameraCond.from_pretrained_2d(
134
+ # cfg.pretrained_model_path,
135
+ # subfolder=cfg.unet_subfolder,
136
+ # unet_additional_kwargs=cfg.unet_additional_kwargs
137
+ # ).to(device)
138
+ # unet.requires_grad_(False)
139
+ #
140
+ #
141
+ # camera_encoder = CameraCameraEncoder(**cfg.camera_encoder_kwargs).to(device)
142
+ # camera_encoder.requires_grad_(False)
143
+ # camera_adaptor = CameraAdaptor(unet, camera_encoder)
144
+ # camera_adaptor.requires_grad_(False)
145
+ # camera_adaptor.to(device)
146
+ #
147
+ # logger.info("Setting the attention processors")
148
+ # unet.set_all_attn_processor(
149
+ # add_spatial_lora=cfg.lora_ckpt is not None,
150
+ # add_motion_lora=cfg.motion_lora_rank > 0,
151
+ # lora_kwargs={"lora_rank": cfg.lora_rank, "lora_scale": cfg.lora_scale},
152
+ # motion_lora_kwargs={"lora_rank": cfg.motion_lora_rank, "lora_scale": cfg.motion_lora_scale},
153
+ # **cfg.attention_processor_kwargs
154
+ # )
155
+ #
156
+ # if cfg.lora_ckpt is not None:
157
+ # print(f"Loading the lora checkpoint from {cfg.lora_ckpt}")
158
+ # lora_checkpoints = torch.load(cfg.lora_ckpt, map_location=unet.device)
159
+ # if 'lora_state_dict' in lora_checkpoints.keys():
160
+ # lora_checkpoints = lora_checkpoints['lora_state_dict']
161
+ # _, lora_u = unet.load_state_dict(lora_checkpoints, strict=False)
162
+ # assert len(lora_u) == 0
163
+ # print(f'Loading done')
164
+ #
165
+ # if cfg.motion_module_ckpt is not None:
166
+ # print(f"Loading the motion module checkpoint from {cfg.motion_module_ckpt}")
167
+ # mm_checkpoints = torch.load(cfg.motion_module_ckpt, map_location=unet.device)
168
+ # _, mm_u = unet.load_state_dict(mm_checkpoints, strict=False)
169
+ # assert len(mm_u) == 0
170
+ # print("Loading done")
171
+ #
172
+ #
173
+ # if cfg.camera_adaptor_ckpt is not None:
174
+ # logger.info(f"Loading camera adaptor from {cfg.camera_adaptor_ckpt}")
175
+ # camera_adaptor_checkpoint = torch.load(cfg.camera_adaptor_ckpt, map_location=device)
176
+ #
177
+ # camera_encoder_state_dict = camera_adaptor_checkpoint['camera_encoder_state_dict']
178
+ # attention_processor_state_dict = camera_adaptor_checkpoint['attention_processor_state_dict']
179
+ #
180
+ # camera_enc_m, camera_enc_u = camera_adaptor.camera_encoder.load_state_dict(camera_encoder_state_dict, strict=False)
181
+ #
182
+ # assert len(camera_enc_m) == 0 and len(camera_enc_u) == 0
183
+ # _, attention_processor_u = camera_adaptor.unet.load_state_dict(attention_processor_state_dict, strict=False)
184
+ # assert len(attention_processor_u) == 0
185
+ #
186
+ # logger.info("Camera Adaptor loading done")
187
+ # else:
188
+ # logger.info("No Camera Adaptor checkpoint used")
189
+ #
190
+ # pipeline = GenPhotoPipeline(
191
+ # vae=vae,
192
+ # text_encoder=text_encoder,
193
+ # tokenizer=tokenizer,
194
+ # unet=unet,
195
+ # scheduler=noise_scheduler,
196
+ # camera_encoder=camera_encoder
197
+ # ).to(device)
198
+ # pipeline.enable_vae_slicing()
199
+ #
200
+ # return pipeline, device
201
+
202
+ def load_models(cfg):
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+
+     # resolve model and checkpoint paths from the Hugging Face Hub
+     pretrained_model_path = hf_hub_download("pandaphd/generative_photography", "stable-diffusion-v1-5/")
+     lora_ckpt_path = hf_hub_download("pandaphd/generative_photography", "weights/RealEstate10K_LoRA.ckpt")
+     motion_module_ckpt_path = hf_hub_download("pandaphd/generative_photography", "weights/v3_sd15_mm.ckpt")
+     camera_adaptor_ckpt_path = hf_hub_download("pandaphd/generative_photography", "weights/checkpoint-shutter_speed.ckpt")
+
+     noise_scheduler = DDIMScheduler(**OmegaConf.to_container(cfg.noise_scheduler_kwargs))
+     vae = AutoencoderKL.from_pretrained(pretrained_model_path, subfolder="vae").to(device)
+     vae.requires_grad_(False)
+
+     tokenizer = CLIPTokenizer.from_pretrained(pretrained_model_path, subfolder="tokenizer")
+     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_path, subfolder="text_encoder").to(device)
+     text_encoder.requires_grad_(False)
+
+     unet = UNet3DConditionModelCameraCond.from_pretrained_2d(
+         pretrained_model_path,
+         subfolder=cfg.unet_subfolder,
+         unet_additional_kwargs=cfg.unet_additional_kwargs
+     ).to(device)
+     unet.requires_grad_(False)
+
+     # inference only: the camera encoder and the adaptor that wraps the UNet are kept frozen
+     camera_encoder = CameraCameraEncoder(**cfg.camera_encoder_kwargs).to(device)
+     camera_encoder.requires_grad_(False)
+     camera_adaptor = CameraAdaptor(unet, camera_encoder)
+     camera_adaptor.requires_grad_(False)
+     camera_adaptor.to(device)
+
+     # install the LoRA and camera-conditioned attention processors on the UNet
+     unet.set_all_attn_processor(
+         add_spatial_lora=cfg.lora_ckpt is not None,
+         add_motion_lora=cfg.motion_lora_rank > 0,
+         lora_kwargs={"lora_rank": cfg.lora_rank, "lora_scale": cfg.lora_scale},
+         motion_lora_kwargs={"lora_rank": cfg.motion_lora_rank, "lora_scale": cfg.motion_lora_scale},
+         **cfg.attention_processor_kwargs
+     )
+
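+     # the checkpoints fetched above are applied on top of the frozen SD 1.5 weights below:
+     # the spatial LoRA (RealEstate10K_LoRA), the motion module (v3_sd15_mm, AnimateDiff-style),
+     # and the shutter-speed camera adaptor (camera encoder + attention processor state dicts).
+     # Note the cfg.*_ckpt entries only gate whether to load; the weights themselves are the Hub files above.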
+     if cfg.lora_ckpt is not None:
+         lora_checkpoints = torch.load(lora_ckpt_path, map_location=unet.device)
+         if 'lora_state_dict' in lora_checkpoints.keys():
+             lora_checkpoints = lora_checkpoints['lora_state_dict']
+         _, lora_u = unet.load_state_dict(lora_checkpoints, strict=False)
+         assert len(lora_u) == 0
+
+     if cfg.motion_module_ckpt is not None:
+         mm_checkpoints = torch.load(motion_module_ckpt_path, map_location=unet.device)
+         _, mm_u = unet.load_state_dict(mm_checkpoints, strict=False)
+         assert len(mm_u) == 0
+
+     if cfg.camera_adaptor_ckpt is not None:
+         camera_adaptor_checkpoint = torch.load(camera_adaptor_ckpt_path, map_location=device)
+         camera_encoder_state_dict = camera_adaptor_checkpoint['camera_encoder_state_dict']
+         attention_processor_state_dict = camera_adaptor_checkpoint['attention_processor_state_dict']
+         camera_enc_m, camera_enc_u = camera_adaptor.camera_encoder.load_state_dict(camera_encoder_state_dict, strict=False)
+         assert len(camera_enc_m) == 0 and len(camera_enc_u) == 0
+         _, attention_processor_u = camera_adaptor.unet.load_state_dict(attention_processor_state_dict, strict=False)
+         assert len(attention_processor_u) == 0
+
+     pipeline = GenPhotoPipeline(
+         vae=vae,
+         text_encoder=text_encoder,
+         tokenizer=tokenizer,
+         unet=unet,
+         scheduler=noise_scheduler,
+         camera_encoder=camera_encoder
+     ).to(device)
+
+     pipeline.enable_vae_slicing()
+     return pipeline, device
+
+
+ def run_inference(pipeline, tokenizer, text_encoder, base_scene, shutter_speed_list, device, video_length=5, height=256, width=384):
+     # shutter_speed_list is a JSON array string, e.g. "[0.1, 0.3, 0.52, 0.7, 0.8]"
+     shutter_speed_values = json.loads(shutter_speed_list)
+     shutter_speed_values = torch.tensor(shutter_speed_values).unsqueeze(1)
+
+     # build the per-frame camera embedding on the correct device and reorder it to (b, c, f, h, w)
+     camera_embedding = Camera_Embedding(shutter_speed_values, tokenizer, text_encoder, device).load()
+     camera_embedding = rearrange(camera_embedding.unsqueeze(0), "b f c h w -> b c f h w")
+
+     with torch.no_grad():
+         sample = pipeline(
+             prompt=base_scene,
+             camera_embedding=camera_embedding,
+             video_length=video_length,
+             height=height,
+             width=width,
+             num_inference_steps=25,
+             guidance_scale=8.0
+         ).videos[0].cpu()
+
+     temporal_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
+     save_videos_grid(sample[None], temporal_video_path, rescale=False)
+     return temporal_video_path
+
+
+ def main(config_path, base_scene, shutter_speed_list):
+     torch.manual_seed(42)
+     cfg = OmegaConf.load(config_path)
+     logger.info("Loading models...")
+     pipeline, device = load_models(cfg)
+     logger.info("Starting inference...")
+     video_path = run_inference(pipeline, pipeline.tokenizer, pipeline.text_encoder, base_scene, shutter_speed_list, device)
+     logger.info(f"Video saved to {video_path}")
+
+
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--config", type=str, required=True, help="Path to YAML configuration file")
+     parser.add_argument("--base_scene", type=str, required=True, help="invariant base scene caption (plain text)")
+     parser.add_argument("--shutter_speed_list", type=str, required=True, help="shutter_speed values as a JSON array string")
+     args = parser.parse_args()
+     main(args.config, args.base_scene, args.shutter_speed_list)
+
+ # usage example:
+ # python inference_shutter_speed.py --config configs/inference_genphoto/adv3_256_384_genphoto_relora_shutter_speed.yaml --base_scene "A modern bathroom with a mirror and soft lighting." --shutter_speed_list "[0.1, 0.3, 0.52, 0.7, 0.8]"
+
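+ # programmatic usage (sketch, using the same config and inputs as the CLI example above):
+ #   cfg = OmegaConf.load("configs/inference_genphoto/adv3_256_384_genphoto_relora_shutter_speed.yaml")
+ #   pipeline, device = load_models(cfg)
+ #   video_path = run_inference(pipeline, pipeline.tokenizer, pipeline.text_encoder,
+ #                              "A modern bathroom with a mirror and soft lighting.",
+ #                              "[0.1, 0.3, 0.52, 0.7, 0.8]", device)
+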
requirements.txt CHANGED
@@ -1,3 +1,19 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:1766bd0739223e95b2fde76b862d853da41c15b0d97273e7e90f4cd4a4d77a60
- size 290
+ --extra-index-url https://download.pytorch.org/whl/cu121
+ torch==2.1.1
+ torchvision==0.16.1
+ torchaudio==2.1.1
+ diffusers==0.24.0
+ imageio==2.36.0
+ imageio-ffmpeg
+ transformers
+ accelerate
+ opencv-python
+ gdown
+ einops
+ decord
+ omegaconf
+ safetensors
+ gradio
+ wandb
+ triton
+ huggingface_hub