tennant committed on
Commit af7c0ce
1 Parent(s): c73cb1d
This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. Dockerfile +58 -0
  2. README.md +1 -1
  3. configs/autoregressive_config.yaml +77 -0
  4. configs/onenode_config.yaml +12 -0
  5. configs/tokenizer_config.yaml +64 -0
  6. demo.py +170 -0
  7. examples/city.jpg +0 -0
  8. examples/food.jpg +0 -0
  9. examples/highland.webp +0 -0
  10. fid_stats/adm_in256_stats.npz +3 -0
  11. gen_demo.py +137 -0
  12. making_cache.py +46 -0
  13. paintmind/__init__.py +2 -0
  14. paintmind/config.py +110 -0
  15. paintmind/engine/gpt_trainer.py +892 -0
  16. paintmind/engine/misc.py +260 -0
  17. paintmind/engine/trainer.py +695 -0
  18. paintmind/engine/util.py +572 -0
  19. paintmind/stage1/__init__.py +0 -0
  20. paintmind/stage1/diffuse_slot.py +808 -0
  21. paintmind/stage1/diffusion/__init__.py +46 -0
  22. paintmind/stage1/diffusion/diffusion_utils.py +88 -0
  23. paintmind/stage1/diffusion/gaussian_diffusion.py +886 -0
  24. paintmind/stage1/diffusion/respace.py +130 -0
  25. paintmind/stage1/diffusion/timestep_sampler.py +150 -0
  26. paintmind/stage1/diffusion_transfomers.py +372 -0
  27. paintmind/stage1/fused_attention.py +94 -0
  28. paintmind/stage1/pos_embed.py +102 -0
  29. paintmind/stage1/quantize.py +93 -0
  30. paintmind/stage1/transport/__init__.py +63 -0
  31. paintmind/stage1/transport/integrators.py +130 -0
  32. paintmind/stage1/transport/path.py +192 -0
  33. paintmind/stage1/transport/transport.py +456 -0
  34. paintmind/stage1/transport/utils.py +29 -0
  35. paintmind/stage1/vision_transformers.py +267 -0
  36. paintmind/stage2/__init__.py +0 -0
  37. paintmind/stage2/causaldit.py +422 -0
  38. paintmind/stage2/diffloss.py +314 -0
  39. paintmind/stage2/generate.py +127 -0
  40. paintmind/stage2/gpt.py +451 -0
  41. paintmind/utils/__init__.py +0 -0
  42. paintmind/utils/datasets.py +77 -0
  43. paintmind/utils/device_utils.py +20 -0
  44. paintmind/utils/logger.py +170 -0
  45. paintmind/utils/lr_scheduler.py +15 -0
  46. paintmind/utils/transform.py +35 -0
  47. paintmind/version.py +1 -0
  48. requirements.txt +27 -0
  49. submitit_test.py +290 -0
  50. submitit_train.py +148 -0
Dockerfile ADDED
@@ -0,0 +1,58 @@
+ FROM nvidia/cuda:12.4.0-runtime-ubuntu22.04
+ LABEL maintainer="Bingchen Zhao"
+ LABEL repository="Semanticist"
+
+ ENV DEBIAN_FRONTEND=noninteractive
+
+ RUN apt-get -y update \
+     && apt-get install -y software-properties-common \
+     && add-apt-repository ppa:deadsnakes/ppa
+
+ RUN apt install -y bash \
+     build-essential \
+     git \
+     git-lfs \
+     curl \
+     ca-certificates \
+     libsndfile1-dev \
+     libgl1 \
+     python3.10 \
+     python3.10-dev \
+     python3-pip \
+     python3.10-venv rsync sudo tmux && \
+     rm -rf /var/lib/apt/lists
+
+ # make sure to use venv
+ RUN python3.10 -m venv /opt/venv
+ ENV PATH="/opt/venv/bin:$PATH"
+
+ # pre-install the heavy dependencies (these can later be overridden by the deps from setup.py)
+ RUN python3.10 -m pip install --no-cache-dir --upgrade pip uv==0.1.11 && \
+     python3.10 -m uv pip install --no-cache-dir \
+     torch \
+     torchvision \
+     torchaudio \
+     invisible_watermark && \
+     python3.10 -m pip install --no-cache-dir \
+     accelerate \
+     datasets \
+     hf-doc-builder \
+     huggingface-hub \
+     hf_transfer \
+     Jinja2 \
+     librosa \
+     numpy==1.26.4 \
+     scipy \
+     tensorboard \
+     transformers \
+     pytorch-lightning matplotlib \
+     hf_transfer
+
+ # start Semanticist part
+ COPY . /work/Semanticist
+ WORKDIR /work/Semanticist
+ RUN ls && python3.10 -m pip install -r req_min.txt && \
+     python3.10 -m pip install git+https://github.com/cocodataset/panopticapi.git
+
+ CMD ["/bin/bash"]
+ # docker run -it --rm --runtime=nvidia --gpus all xx/xx:xx
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: gray
  colorTo: red
  sdk: gradio
  sdk_version: 5.20.1
- app_file: app.py
+ app_file: demo.py
  pinned: false
  license: mit
  ---
configs/autoregressive_config.yaml ADDED
@@ -0,0 +1,77 @@
+ trainer:
+   target: paintmind.engine.gpt_trainer.GPTTrainer
+   params:
+     num_epoch: 400
+     blr: 1e-4
+     cosine_lr: False
+     warmup_epochs: 100
+     batch_size: 32
+     num_workers: 8
+     pin_memory: True
+     grad_accum_steps: 1
+     precision: 'bf16'
+     max_grad_norm: 1.0
+     enable_ema: True
+     save_every: 10000
+     sample_every: 5000
+     fid_every: 50000
+     eval_fid: False
+     result_folder: "./output/autoregressive"
+     log_dir: "./output/autoregressive/logs"
+     ae_cfg: 1.5
+     cfg: 1.5
+     cfg_schedule: "constant"
+     train_num_slots: 32
+     test_num_slots: 32
+     compile: True
+     enable_cache_latents: True
+     ae_model:
+       target: paintmind.stage1.diffuse_slot.DiffuseSlot
+       params:
+         encoder: 'vit_base_patch16'
+         enc_img_size: 256
+         enc_causal: True
+         enc_use_mlp: False
+         num_slots: 256
+         slot_dim: 16
+         norm_slots: True
+         dit_mask_type: 'replace'
+         cond_method: 'token'
+         dit_model: 'DiT-XL-2'
+         vae: 'xwen99/mar-vae-kl16'
+         num_sampling_steps: '250'
+         ckpt_path: ./output/tokenizer/models/step250000/custom_checkpoint_1.pkl
+
+     gpt_model:
+       target: GPT-L
+       params:
+         num_slots: 32
+         slot_dim: 16
+         num_classes: 1000
+         cls_token_num: 1
+         resid_dropout_p: 0.1
+         ffn_dropout_p: 0.1
+         diffloss_d: 12
+         diffloss_w: 1536
+         num_sampling_steps: '100'
+         diffusion_batch_mul: 4
+         token_drop_prob: 0
+         use_si: True
+         cond_method: 'concat'
+         decoupled_cfg: False
+         ckpt_path: None
+
+     dataset:
+       target: paintmind.utils.datasets.ImageNet
+       params:
+         root: ./dataset/imagenet/
+         split: train
+         aug: tencrop_cached
+         img_size: 256
+
+     test_dataset:
+       target: paintmind.utils.datasets.ImageNet
+       params:
+         root: ./dataset/imagenet/
+         split: val
+         img_size: 256
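
A minimal sketch of how a config like this is typically consumed, following the same OmegaConf + `instantiate_from_config` pattern used by `making_cache.py` in this commit; the standalone entry-point script itself is an assumption, not part of this diff:

```python
# Hedged sketch: build and run the autoregressive trainer from the YAML above.
# The surrounding script is hypothetical; the loading pattern mirrors making_cache.py.
from omegaconf import OmegaConf
from paintmind.engine.util import instantiate_from_config

cfg = OmegaConf.load('configs/autoregressive_config.yaml')
trainer = instantiate_from_config(cfg.trainer)  # -> paintmind.engine.gpt_trainer.GPTTrainer
trainer.train(cfg)                              # GPTTrainer.train(config=...) also saves the config
```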
configs/onenode_config.yaml ADDED
@@ -0,0 +1,12 @@
+ compute_environment: LOCAL_MACHINE
+ deepspeed_config: {}
+ distributed_type: MULTI_GPU
+ fsdp_config: {}
+ machine_rank: 0
+ main_process_ip: null
+ main_process_port: null
+ main_training_function: main
+ mixed_precision: bf16
+ num_machines: 1
+ num_processes: 8
+ use_cpu: false
configs/tokenizer_config.yaml ADDED
@@ -0,0 +1,64 @@
+ trainer:
+   target: paintmind.engine.trainer.DiffusionTrainer
+   params:
+     num_epoch: 400
+     valid_size: 64
+     blr: 2.5e-5
+     cosine_lr: True
+     warmup_epochs: 100
+     batch_size: 32
+     num_workers: 16
+     pin_memory: True
+     grad_accum_steps: 1
+     precision: 'bf16'
+     max_grad_norm: 3.0
+     enable_ema: True
+     save_every: 10000
+     sample_every: 5000
+     fid_every: 50000
+     result_folder: "./output/tokenizer"
+     log_dit: "./output/tokenizer/logs"
+     cfg: 3.0
+     compile: True
+     model:
+       target: paintmind.stage1.diffuse_slot.DiffuseSlot
+       params:
+         encoder: 'vit_base_patch16'
+         enc_img_size: 256
+         enc_causal: True
+         enc_use_mlp: False
+         num_slots: 256
+         slot_dim: 16
+         norm_slots: True
+         dit_mask_type: 'replace'
+         cond_method: 'token'
+         dit_model: 'DiT-XL-2'
+         vae: 'xwen99/mar-vae-kl16'
+         enable_nest: False
+         enable_nest_after: 50
+         nest_rho: 0.03
+         nest_dist: uniform
+         nest_null_prob: 0
+         nest_allow_zero: False
+         use_repa: True
+         repa_encoder: dinov2_vitb
+         repa_encoder_depth: 8
+         repa_loss_weight: 1.0
+         eval_fid: True
+         fid_stats: 'fid_stats/adm_in256_stats.npz'
+         num_sampling_steps: '250'
+         ckpt_path: None
+
+     dataset:
+       target: paintmind.utils.datasets.ImageNet
+       params:
+         root: ./dataset/imagenet/
+         split: train
+         img_size: 256
+
+     test_dataset:
+       target: paintmind.utils.datasets.ImageNet
+       params:
+         root: ./dataset/imagenet/
+         split: val
+         img_size: 256
demo.py ADDED
@@ -0,0 +1,170 @@
+ import gradio as gr
+ import numpy as np
+ from PIL import Image
+ import os
+ import torch
+ import matplotlib.pyplot as plt
+ from omegaconf import OmegaConf
+ from huggingface_hub import hf_hub_download
+
+ from paintmind.engine.util import instantiate_from_config
+ from paintmind.stage1.diffuse_slot import DiffuseSlot
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ ckpt_path = hf_hub_download(repo_id='tennant/semanticist', filename='semanticist_tok_XL.pkl')
+ config_path = 'configs/tokenizer_config.yaml'
+ cfg = OmegaConf.load(config_path)
+ ckpt = torch.load(ckpt_path, map_location='cpu')
+ from paintmind.utils.datasets import vae_transforms
+ from PIL import Image
+
+ transform = vae_transforms('test')
+
+
+ def norm_ip(img, low, high):
+     img.clamp_(min=low, max=high)
+     img.sub_(low).div_(max(high - low, 1e-5))
+
+ def norm_range(t, value_range):
+     if value_range is not None:
+         norm_ip(t, value_range[0], value_range[1])
+     else:
+         norm_ip(t, float(t.min()), float(t.max()))
+
+ from PIL import Image
+ def convert_np(img):
+     ndarr = img.mul(255).add_(0.5).clamp_(0, 255)\
+         .permute(1, 2, 0).to("cpu", torch.uint8).numpy()
+     return ndarr
+ def convert_PIL(img):
+     ndarr = img.mul(255).add_(0.5).clamp_(0, 255)\
+         .permute(1, 2, 0).to("cpu", torch.uint8).numpy()
+     img = Image.fromarray(ndarr)
+     return img
+
+ ckpt = {k.replace('._orig_mod', ''): v for k, v in ckpt.items()}
+
+ model = DiffuseSlot(**cfg['trainer']['params']['model']['params'])
+ msg = model.load_state_dict(ckpt, strict=False)
+ model = model.to(device)
+ model = model.eval()
+ model.enable_nest = True
+
+ def viz_diff_slots(model, img, nums, cfg=1.0, return_img=False):
+     n_slots_inf = []
+     for num_slots_to_inference in nums:
+         recon_n = model(
+             img, None, sample=True, cfg=cfg,
+             inference_with_n_slots=num_slots_to_inference,
+         )
+         n_slots_inf.append(recon_n)
+     return [convert_np(n_slots_inf[i][0]) for i in range(len(n_slots_inf))]
+
+ # Removed process_image function as its functionality is now in the update_outputs function
+
+ with gr.Blocks() as demo:
+     with gr.Row():
+         # First column - Input and configs
+         with gr.Column(scale=1):
+             gr.Markdown("## Input")
+             input_image = gr.Image(label="Upload an image", type="numpy")
+
+             with gr.Group():
+                 gr.Markdown("### Configuration")
+                 show_gallery = gr.Checkbox(label="Show Gallery", value=False)
+                 # You can add more config options here
+                 # slider = gr.Slider(minimum=0, maximum=10, value=5, label="Processing Intensity")
+                 slider = gr.Slider(minimum=0.1, maximum=20.0, value=4.0, label="CFG value")
+                 labels_input = gr.Textbox(
+                     label="Gallery Labels (comma-separated)",
+                     value="1, 4, 16, 64, 256",
+                     placeholder="Enter comma-separated numbers for the number of slots to use"
+                 )
+
+         # Second column - Output (conditionally rendered)
+         with gr.Column(scale=1):
+             gr.Markdown("## Output")
+
+             # Container for conditional rendering
+             with gr.Group(visible=False) as gallery_container:
+                 gallery = gr.Gallery(label="Result Gallery", columns=3, height="auto", show_label=True)
+
+             # Always visible output image
+             output_image = gr.Image(label="Processed Image", type="numpy")
+
+     # Handle form submission
+     submit_btn = gr.Button("Process")
+
+     # Define the processing logic
+     def update_outputs(image, show_gallery_value, slider_value, labels_text):
+         # Update the visibility of the gallery container
+         gallery_container.visible = show_gallery_value
+
+         try:
+             # Parse the labels from the text input
+             if labels_text and "," in labels_text:
+                 labels = [int(label.strip()) for label in labels_text.split(",")]
+             else:
+                 # Default labels if none provided or in wrong format
+                 labels = [1, 4, 16, 64, 256]
+         except:
+             labels = [1, 4, 16, 64, 256]
+         while len(labels) < 3:
+             labels.append(256)
+
+         # Process the image based on configurations
+         if image is None:
+             # Return placeholder if no image is uploaded
+             placeholder = np.zeros((300, 300, 3), dtype=np.uint8)
+             return gallery_container, [], placeholder
+         image = Image.fromarray(image)
+         img = transform(image)
+         img = img.unsqueeze(0).to(device)
+         recon = viz_diff_slots(model, img, [256], cfg=slider_value)[0]
+
+
+         if not show_gallery_value:
+             # If only the image should be shown, return just the processed image
+             return gallery_container, [], recon
+         else:
+             model_decompose = viz_diff_slots(model, img, labels, cfg=slider_value)
+             # Create image variations and pair them with labels
+             gallery_images = [
+                 (image, 'GT'),
+                 # (np.array(Image.fromarray(image).convert("L").convert("RGB")), labels[1]),
+                 # (np.array(Image.fromarray(image).rotate(180)), labels[2])
+             ] + [(img, 'Recon. with ' + str(label) + ' tokens') for img, label in zip(model_decompose, labels)]
+             return gallery_container, gallery_images, image
+
+     # Connect the inputs and outputs
+     submit_btn.click(
+         fn=update_outputs,
+         inputs=[input_image, show_gallery, slider, labels_input],
+         outputs=[gallery_container, gallery, output_image]
+     )
+
+     # Also update when checkbox changes
+     show_gallery.change(
+         fn=lambda value: gr.update(visible=value),
+         inputs=[show_gallery],
+         outputs=[gallery_container]
+     )
+
+     # Add examples
+     examples = [
+         ["examples/city.jpg", False, 4.0, "1,4,16,64,256"],
+         ["examples/food.jpg", True, 4.0, "1,4,16,64,256"],
+         ["examples/highland.webp", True, 4.0, "1,4,16,64,256"],
+     ]
+
+     gr.Examples(
+         examples=examples,
+         inputs=[input_image, show_gallery, slider, labels_input],
+         outputs=[gallery_container, gallery, output_image],
+         fn=update_outputs,
+         cache_examples=True
+     )
+
+ # Launch the demo
+ if __name__ == "__main__":
+     demo.launch()
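
The same tokenizer can also be exercised without the Gradio UI. A small sketch reusing the objects `demo.py` builds above (`model`, `transform`, `viz_diff_slots`, `device`); the example image path is just a placeholder:

```python
# Hedged sketch: reconstruct one image at a few token budgets, outside the UI.
from PIL import Image

image = Image.open('examples/city.jpg').convert('RGB')      # any RGB image works
img = transform(image).unsqueeze(0).to(device)
recons = viz_diff_slots(model, img, [1, 16, 256], cfg=4.0)   # list of HxWx3 uint8 arrays
for n, arr in zip([1, 16, 256], recons):
    Image.fromarray(arr).save(f'recon_{n}_tokens.png')
```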
examples/city.jpg ADDED
examples/food.jpg ADDED
examples/highland.webp ADDED
fid_stats/adm_in256_stats.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:e842c756177806d210e891893914620bee2cf5b779a5613ba5af5145d7c85289
+ size 33563124
gen_demo.py ADDED
@@ -0,0 +1,137 @@
+ import gradio as gr
+ import numpy as np
+ import cv2
+ import time
+ from PIL import Image
+ import io
+
+ def apply_transformations(image, numbers_str):
+     """
+     Apply a series of transformations to an image based on a list of numbers.
+     Shows the progressive changes as each transformation is applied.
+     Returns both the current image and the full gallery of transformations.
+     """
+     try:
+         # Parse the input numbers
+         numbers = [float(n.strip()) for n in numbers_str.split(',') if n.strip()]
+         if not numbers:
+             return image, [(image, "Original Image")]
+
+         # Convert PIL Image to numpy array for OpenCV operations
+         img = np.array(image)
+
+         # Initialize the result list with the original image
+         results = [(image, "Original Image")]
+         current_image = image
+
+         # Apply transformations based on each number
+         for i, value in enumerate(numbers):
+             # Make a copy of the current numpy image
+             if i == 0:
+                 current_img = img.copy()
+             else:
+                 current_img = np.array(current_image)
+
+             # Apply different transformations based on the value
+             transformation_type = ""
+             if i % 5 == 0:  # Brightness adjustment
+                 # Scale value to reasonable brightness adjustment
+                 brightness = max(min(value, 100), -100)  # Limit between -100 and 100
+                 current_img = cv2.addWeighted(current_img, 1, np.zeros_like(current_img), 0, brightness)
+                 transformation_type = f"Brightness: {brightness:.1f}"
+
+             elif i % 5 == 1:  # Rotation
+                 # Scale value to reasonable rotation angle
+                 angle = value % 360
+                 h, w = current_img.shape[:2]
+                 center = (w // 2, h // 2)
+                 rotation_matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
+                 current_img = cv2.warpAffine(current_img, rotation_matrix, (w, h))
+                 transformation_type = f"Rotation: {angle:.1f}°"
+
+             elif i % 5 == 2:  # Contrast adjustment
+                 # Scale value to reasonable contrast adjustment
+                 contrast = max(min(value / 10, 3), 0.5)  # Limit between 0.5 and 3
+                 current_img = cv2.convertScaleAbs(current_img, alpha=contrast, beta=0)
+                 transformation_type = f"Contrast: {contrast:.1f}x"
+
+             elif i % 5 == 3:  # Blur
+                 # Scale value to reasonable blur kernel size
+                 blur_amount = max(int(abs(value) % 20), 1)
+                 if blur_amount % 2 == 0:  # Ensure kernel size is odd
+                     blur_amount += 1
+                 current_img = cv2.GaussianBlur(current_img, (blur_amount, blur_amount), 0)
+                 transformation_type = f"Blur: {blur_amount}px"
+
+             elif i % 5 == 4:  # Hue shift (for color images)
+                 if current_img.shape[-1] == 3:  # Only for color images
+                     # Convert to HSV
+                     hsv_img = cv2.cvtColor(current_img, cv2.COLOR_RGB2HSV)
+                     # Shift hue
+                     hue_shift = int(value) % 180
+                     hsv_img[:, :, 0] = (hsv_img[:, :, 0] + hue_shift) % 180
+                     # Convert back to RGB
+                     current_img = cv2.cvtColor(hsv_img, cv2.COLOR_HSV2RGB)
+                     transformation_type = f"Hue Shift: {hue_shift}°"
+
+             # Convert back to PIL Image and add to results
+             current_image = Image.fromarray(current_img)
+
+             # Add to results with a label for the gallery
+             results.append((current_image, f"Step {i+1}: {transformation_type}"))
+
+             # (Progress updates removed)
+
+             # Add a small delay to make the progressive changes visible
+             time.sleep(4)
+
+             # Yield intermediate results for real-time updates
+             if i < len(numbers) - 1:
+                 yield current_image, results
+
+         return current_image, results
+
+     except Exception as e:
+         error_msg = f"Error: {str(e)}"
+         return image, [(image, "Error")]
+
+ # Create Gradio Interface
+ with gr.Blocks() as demo:
+     gr.Markdown("# Image Transformation Demo")
+     gr.Markdown("Upload an image and provide a comma-separated list of numbers. The demo will apply a series of transformations to the image based on these numbers.")
+
+     with gr.Row():
+         with gr.Column(scale=1):
+             input_image = gr.Image(label="Input Image", type="pil")
+             numbers_input = gr.Textbox(label="Transformation Values (comma-separated numbers)",
+                                        placeholder="e.g., 50, -30, 1.5, 5, 90, 20",
+                                        value="30, 45, 1.5, 3, 60, -20, 90, 1.8, 7, 120")
+             transform_btn = gr.Button("Apply Transformations")
+
+             explanation = gr.Markdown("""
+             ## How the transformations work:
+
+             The numbers you input will be used to apply these transformations in sequence:
+             1. First number: Brightness adjustment (-100 to 100)
+             2. Second number: Rotation (degrees)
+             3. Third number: Contrast adjustment (0.5 to 3)
+             4. Fourth number: Blur (kernel size)
+             5. Fifth number: Hue shift (color images only)
+
+             And the pattern repeats for longer lists of numbers.
+             """)
+
+         with gr.Column(scale=2):
+             with gr.Row():
+                 current_image = gr.Image(label="Current Transformation", type="pil")
+             with gr.Row():
+                 gallery = gr.Gallery(label="Transformation History", show_label=True, columns=4, rows=2, height="auto")
+
+     transform_btn.click(
+         fn=apply_transformations,
+         inputs=[input_image, numbers_input],
+         outputs=[current_image, gallery]
+     )
+
+ # Launch the app
+ demo.launch()
making_cache.py ADDED
@@ -0,0 +1,46 @@
+ import numpy as np
+ import os, pdb, time
+ import torch_fidelity
+ import tqdm
+ import torch
+ import os.path as osp
+ import argparse
+ from omegaconf import OmegaConf
+ from paintmind.engine.util import instantiate_from_config
+
+
+ @torch.no_grad()
+ def caching():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--cfg', type=str, default='configs/vit_vqgan.yaml')
+     args = parser.parse_args()
+
+     cfg_file = args.cfg
+     assert osp.exists(cfg_file)
+     config = OmegaConf.load(cfg_file)
+     dataset = instantiate_from_config(config.trainer.params.dataset)
+     model = instantiate_from_config(config.trainer.params.model)
+     dataloader = torch.utils.data.DataLoader(
+         dataset,
+         batch_size=config.trainer.params.batch_size,
+         shuffle=False,
+         num_workers=config.trainer.params.num_workers,
+     )
+     # Each batch will give us a (N, C, H, W) tensor of images
+     # We need to cache them and save them to a pth file
+     cache_save_file = config.trainer.params.latent_cache_file
+     cache = []
+     # import ipdb; ipdb.set_trace()
+     model.cuda()
+     model.eval()
+     for idx, batch in enumerate(tqdm.tqdm(dataloader)):
+         batch = batch[0].cuda()
+         latent = model.vae_encode(batch)
+         cache.append(latent.cpu())
+     cache = torch.cat(cache, dim=0)
+     torch.save(cache, cache_save_file)
+
+ if __name__ == '__main__':
+
+     caching()
+
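
The script writes a single tensor of VAE latents to `config.trainer.params.latent_cache_file`, a key the target config is assumed to define (it does not appear in the configs added by this commit). Reading the cache back is a plain `torch.load`, for example:

```python
# Hedged sketch: reload the latent cache produced by making_cache.py.
# 'latents.pth' stands in for whatever latent_cache_file points to.
import torch

cache = torch.load('latents.pth', map_location='cpu')    # one latent per image, concatenated on dim 0
print(cache.shape, cache.dtype)
dataset = torch.utils.data.TensorDataset(cache)           # wrap for a DataLoader if needed
```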
paintmind/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from .config import Config
+ from .version import __version__
paintmind/config.py ADDED
@@ -0,0 +1,110 @@
+ import json, pdb
+ from copy import deepcopy
+
+ class Config:
+     def __init__(self, config=None):
+         if config is not None:
+             self.from_dict(config)
+
+     def __repr__(self):
+         return str(self.to_json_string())
+
+     def to_dict(self):
+         return deepcopy(self.__dict__)
+
+     def to_json(self, path):
+         with open(path, 'w') as f:
+             json.dump(self.to_dict(), f, indent=2)
+
+     def to_json_string(self):
+         return json.dumps(self.to_dict(), indent=2)
+
+     def from_dict(self, dct):
+         self.clear()
+         for key, value in dct.items():
+             self.__dict__[key] = value
+
+         return self.to_dict()
+
+     def from_json(self, json_path):
+         with open(json_path, 'r') as f:
+             config = json.load(f)
+         self.from_dict(config)
+
+         return self.to_dict()
+
+     def clear(self):
+         del self.__dict__
+
+
+ vit_s_vqgan_config = {
+     'n_embed' :8192,
+     'embed_dim' :16,
+     'beta' :0.25,
+     'enc':{
+         'image_size':320,
+         'patch_size':8,
+         'dim':512,
+         'depth':8,
+         'num_head':8,
+         'mlp_dim':2048,
+         'in_channels':3,
+         'dim_head':64,
+         'dropout':0.0,
+     },
+     'dec':{
+         'image_size':320,
+         'patch_size':8,
+         'dim':512,
+         'depth':8,
+         'num_head':8,
+         'mlp_dim':2048,
+         'out_channels':3,
+         'dim_head':64,
+         'dropout':0.0,
+     },
+ }
+ vit_m_vqgan_config = {
+     'n_embed' :8192,
+     'embed_dim' :32,
+     'beta' :0.25,
+     'enc':{
+         'image_size':256,
+         'patch_size':8,
+         'dim': 1024,
+         'depth': 16,
+         'num_head':16,
+         'mlp_dim':2048,
+         'in_channels':3,
+         'dim_head':64,
+         'dropout':0.0,
+     },
+     'dec':{
+         'image_size':256,
+         'patch_size':8,
+         'dim':1024,
+         'depth':16,
+         'num_head':16,
+         'mlp_dim':2048,
+         'out_channels':3,
+         'dim_head':64,
+         'dropout':0.0,
+     },
+ }
+
+ pipeline_v1_config = {
+     'stage1' :'vit-s-vqgan',
+     't5' :'t5-l',
+     'dim' :1024,
+     'dim_head' :64,
+     'mlp_dim' :4096,
+     'num_head' :16,
+     'depth' :12,
+     'dropout' :0.1,
+ }
+
+ ver2cfg = {
+     'vit-s-vqgan' : vit_s_vqgan_config,
+     'vit-m-vqgan' : vit_m_vqgan_config,
+     'paintmindv1' : pipeline_v1_config,
+ }
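
`Config` is a thin attribute-style wrapper around these preset dictionaries. A quick usage sketch (the output file name is arbitrary):

```python
# Hedged sketch of the Config helper defined above.
from paintmind import Config
from paintmind.config import ver2cfg

cfg = Config(ver2cfg['vit-s-vqgan'])   # attributes mirror the dict keys
print(cfg.n_embed, cfg.embed_dim)      # 8192 16
cfg.to_json('vit_s_vqgan.json')        # round-trips via to_json / from_json
```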
paintmind/engine/gpt_trainer.py ADDED
@@ -0,0 +1,892 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, torch
2
+ import os.path as osp
3
+ import cv2
4
+ import shutil
5
+ import numpy as np
6
+ import copy
7
+ import torch_fidelity
8
+ import torch.nn as nn
9
+ from tqdm.auto import tqdm
10
+ from collections import OrderedDict
11
+ from einops import rearrange
12
+ from accelerate import Accelerator
13
+ from .util import instantiate_from_config
14
+ from torchvision.utils import make_grid, save_image
15
+ from torch.utils.data import DataLoader, random_split, DistributedSampler, Sampler
16
+ from paintmind.utils.lr_scheduler import build_scheduler
17
+ from paintmind.utils.logger import SmoothedValue, MetricLogger, synchronize_processes, empty_cache
18
+ from paintmind.engine.misc import is_main_process, all_reduce_mean, concat_all_gather
19
+ from accelerate.utils import DistributedDataParallelKwargs, AutocastKwargs
20
+ from torch.optim import AdamW
21
+ from concurrent.futures import ThreadPoolExecutor
22
+ from paintmind.stage2.gpt import GPT_models
23
+ from paintmind.stage2.causaldit import CausalDiT_models
24
+ from paintmind.stage2.generate import generate, generate_causal_dit
25
+ from pathlib import Path
26
+ import time
27
+
28
+
29
+ def requires_grad(model, flag=True):
30
+ for p in model.parameters():
31
+ p.requires_grad = flag
32
+
33
+
34
+ def save_img(img, save_path):
35
+ img = np.clip(img.float().numpy().transpose([1, 2, 0]) * 255, 0, 255)
36
+ img = img.astype(np.uint8)[:, :, ::-1]
37
+ cv2.imwrite(save_path, img)
38
+
39
+ def save_img_batch(imgs, save_paths):
40
+ """Process and save multiple images at once using a thread pool."""
41
+ # Convert to numpy and prepare all images in one go
42
+ imgs = np.clip(imgs.float().numpy().transpose(0, 2, 3, 1) * 255, 0, 255).astype(np.uint8)
43
+ imgs = imgs[:, :, :, ::-1] # RGB to BGR for all images at once
44
+
45
+ # Use ProcessPoolExecutor which is generally better for CPU-bound tasks
46
+ # ThreadPoolExecutor is better for I/O-bound tasks like file saving
47
+ with ThreadPoolExecutor(max_workers=32) as pool:
48
+ # Submit all tasks at once
49
+ futures = [pool.submit(cv2.imwrite, path, img)
50
+ for path, img in zip(save_paths, imgs)]
51
+ # Wait for all tasks to complete
52
+ for future in futures:
53
+ future.result() # This will raise any exceptions that occurred
54
+
55
+ def get_fid_stats(real_dir, rec_dir, fid_stats):
56
+ stats = torch_fidelity.calculate_metrics(
57
+ input1=real_dir,
58
+ input2=rec_dir,
59
+ fid_statistics_file=fid_stats,
60
+ cuda=True,
61
+ isc=True,
62
+ fid=True,
63
+ kid=False,
64
+ prc=False,
65
+ verbose=False,
66
+ )
67
+ return stats
68
+
69
+
70
+ class EMAModel:
71
+ """Model Exponential Moving Average."""
72
+ def __init__(self, model, device, decay=0.999):
73
+ self.device = device
74
+ self.decay = decay
75
+ self.ema_params = OrderedDict(
76
+ (name, param.clone().detach().to(device))
77
+ for name, param in model.named_parameters()
78
+ if param.requires_grad
79
+ )
80
+
81
+ @torch.no_grad()
82
+ def update(self, model):
83
+ for name, param in model.named_parameters():
84
+ if param.requires_grad:
85
+ if name in self.ema_params:
86
+ self.ema_params[name].lerp_(param.data, 1 - self.decay)
87
+ else:
88
+ self.ema_params[name] = param.data.clone().detach()
89
+
90
+ def state_dict(self):
91
+ return self.ema_params
92
+
93
+ def load_state_dict(self, params):
94
+ self.ema_params = OrderedDict(
95
+ (name, param.clone().detach().to(self.device))
96
+ for name, param in params.items()
97
+ )
98
+
99
+ class CacheDataLoader:
100
+ """DataLoader-like interface for cached data with epoch-based shuffling."""
101
+ def __init__(self, slots, targets=None, batch_size=32, num_augs=1, seed=None):
102
+ self.slots = slots
103
+ self.targets = targets
104
+ self.batch_size = batch_size
105
+ self.num_augs = num_augs
106
+ self.seed = seed
107
+ self.epoch = 0
108
+ # Original dataset size (before augmentations)
109
+ self.num_samples = len(slots) // num_augs
110
+
111
+ def set_epoch(self, epoch):
112
+ """Set epoch for deterministic shuffling."""
113
+ self.epoch = epoch
114
+
115
+ def __len__(self):
116
+ """Return number of batches based on original dataset size."""
117
+ return self.num_samples // self.batch_size
118
+
119
+ def __iter__(self):
120
+ """Return random indices for current epoch."""
121
+ g = torch.Generator()
122
+ g.manual_seed(self.seed + self.epoch if self.seed is not None else self.epoch)
123
+
124
+ # Randomly sample indices from the entire augmented dataset
125
+ indices = torch.randint(
126
+ 0, len(self.slots),
127
+ (self.num_samples,),
128
+ generator=g
129
+ ).numpy()
130
+
131
+ # Yield batches of indices
132
+ for start in range(0, self.num_samples, self.batch_size):
133
+ end = min(start + self.batch_size, self.num_samples)
134
+ batch_indices = indices[start:end]
135
+ yield (
136
+ torch.from_numpy(self.slots[batch_indices]),
137
+ torch.from_numpy(self.targets[batch_indices])
138
+ )
139
+
140
+ class GPTTrainer(nn.Module):
141
+ def __init__(
142
+ self,
143
+ ae_model,
144
+ gpt_model,
145
+ dataset,
146
+ test_dataset=None,
147
+ test_only=False,
148
+ num_test_images=50000,
149
+ num_epoch=400,
150
+ eval_classes=[1, 7, 282, 604, 724, 207, 250, 751, 404, 850], # goldfish, cock, tiger cat, hourglass, ship, golden retriever, husky, race car, airliner, teddy bear
151
+ lr=None,
152
+ blr=1e-4,
153
+ cosine_lr=False,
154
+ lr_min=0,
155
+ warmup_epochs=100,
156
+ warmup_steps=None,
157
+ warmup_lr_init=0,
158
+ decay_steps=None,
159
+ batch_size=32,
160
+ cache_bs=8,
161
+ test_bs=100,
162
+ num_workers=0,
163
+ pin_memory=False,
164
+ max_grad_norm=None,
165
+ grad_accum_steps=1,
166
+ precision="bf16",
167
+ save_every=10000,
168
+ sample_every=1000,
169
+ fid_every=50000,
170
+ result_folder=None,
171
+ log_dir="./log",
172
+ steps=0,
173
+ cfg=1.75,
174
+ ae_cfg=1.5,
175
+ diff_cfg=2.0,
176
+ temperature=1.0,
177
+ cfg_schedule="constant",
178
+ diff_cfg_schedule="inv_linear",
179
+ train_num_slots=None,
180
+ test_num_slots=None,
181
+ eval_fid=False,
182
+ fid_stats=None,
183
+ enable_ema=False,
184
+ compile=False,
185
+ enable_cache_latents=True,
186
+ cache_dir='/dev/shm/slot_cache',
187
+ seed=42
188
+ ):
189
+ super().__init__()
190
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
191
+ self.accelerator = Accelerator(
192
+ kwargs_handlers=[kwargs],
193
+ mixed_precision="bf16",
194
+ gradient_accumulation_steps=grad_accum_steps,
195
+ log_with="tensorboard",
196
+ project_dir=log_dir,
197
+ )
198
+
199
+ self.ae_model = instantiate_from_config(ae_model)
200
+ if hasattr(ae_model.params, "ema_path") and ae_model.params.ema_path is not None:
201
+ ae_model_path = ae_model.params.ema_path
202
+ else:
203
+ ae_model_path = ae_model.params.ckpt_path
204
+ assert ae_model_path.endswith(".safetensors") or ae_model_path.endswith(".pt") or ae_model_path.endswith(".pth") or ae_model_path.endswith(".pkl")
205
+ assert osp.exists(ae_model_path), f"AE model checkpoint {ae_model_path} does not exist"
206
+ self._load_checkpoint(ae_model_path, self.ae_model)
207
+
208
+ self.ae_model.to(self.device)
209
+ for param in self.ae_model.parameters():
210
+ param.requires_grad = False
211
+ self.ae_model.eval()
212
+
213
+ self.model_name = gpt_model.target
214
+ if 'GPT' in gpt_model.target:
215
+ self.gpt_model = GPT_models[gpt_model.target](**gpt_model.params)
216
+ elif 'CausalDiT' in gpt_model.target:
217
+ self.gpt_model = CausalDiT_models[gpt_model.target](**gpt_model.params)
218
+ else:
219
+ raise ValueError(f"Unknown model type: {gpt_model.target}")
220
+ self.num_slots = ae_model.params.num_slots
221
+ self.slot_dim = ae_model.params.slot_dim
222
+
223
+ assert precision in ["bf16", "fp32"]
224
+ precision = "fp32"
225
+ if self.accelerator.is_main_process:
226
+ print("Overlooking specified precision and using autocast bf16...")
227
+ self.precision = precision
228
+
229
+ self.test_only = test_only
230
+ self.test_bs = test_bs
231
+ self.num_test_images = num_test_images
232
+ self.num_classes = gpt_model.params.num_classes
233
+
234
+ self.batch_size = batch_size
235
+ if not test_only:
236
+ self.train_ds = instantiate_from_config(dataset)
237
+ train_size = len(self.train_ds)
238
+ if self.accelerator.is_main_process:
239
+ print(f"train dataset size: {train_size}")
240
+
241
+ sampler = DistributedSampler(
242
+ self.train_ds,
243
+ num_replicas=self.accelerator.num_processes,
244
+ rank=self.accelerator.process_index,
245
+ shuffle=True,
246
+ )
247
+ self.train_dl = DataLoader(
248
+ self.train_ds,
249
+ batch_size=batch_size if not enable_cache_latents else cache_bs,
250
+ sampler=sampler,
251
+ num_workers=num_workers,
252
+ pin_memory=pin_memory,
253
+ drop_last=True,
254
+ )
255
+
256
+ effective_bs = batch_size * grad_accum_steps * self.accelerator.num_processes
257
+ if lr is None:
258
+ lr = blr * effective_bs / 256
259
+ if self.accelerator.is_main_process:
260
+ print(f"Effective batch size is {effective_bs}")
261
+
262
+ self.g_optim = self._creat_optimizer(weight_decay=0.05, learning_rate=lr, betas=(0.9, 0.95))
263
+ self.g_sched = self._create_scheduler(
264
+ cosine_lr, warmup_epochs, warmup_steps, num_epoch,
265
+ lr_min, warmup_lr_init, decay_steps
266
+ )
267
+ self.accelerator.register_for_checkpointing(self.g_sched)
268
+
269
+ self.steps = steps
270
+ self.loaded_steps = -1
271
+
272
+ # Prepare everything together
273
+ if not test_only:
274
+ self.gpt_model, self.g_optim, self.g_sched = self.accelerator.prepare(
275
+ self.gpt_model, self.g_optim, self.g_sched
276
+ )
277
+ else:
278
+ self.gpt_model = self.accelerator.prepare(self.gpt_model)
279
+
280
+ # assume _ori_model does not exist in checkpoints
281
+ if compile:
282
+ _model = self.accelerator.unwrap_model(self.gpt_model)
283
+ self.ae_model = torch.compile(self.ae_model, mode="reduce-overhead")
284
+ _model = torch.compile(_model, mode="reduce-overhead")
285
+
286
+ self.enable_ema = enable_ema
287
+ if self.enable_ema and not self.test_only: # when testing, we directly load the ema dict and skip here
288
+ self.ema_model = EMAModel(self.accelerator.unwrap_model(self.gpt_model), self.device)
289
+ self.accelerator.register_for_checkpointing(self.ema_model)
290
+
291
+ self._load_checkpoint(gpt_model.params.ckpt_path)
292
+ if self.test_only:
293
+ self.steps = self.loaded_steps
294
+
295
+ self.num_epoch = num_epoch
296
+ self.save_every = save_every
297
+ self.samp_every = sample_every
298
+ self.fid_every = fid_every
299
+ self.max_grad_norm = max_grad_norm
300
+
301
+ self.eval_classes = eval_classes
302
+ self.cfg = cfg
303
+ self.ae_cfg = ae_cfg
304
+ self.diff_cfg = diff_cfg
305
+ self.cfg_schedule = cfg_schedule
306
+ self.diff_cfg_schedule = diff_cfg_schedule
307
+ self.temperature = temperature
308
+ self.train_num_slots = train_num_slots
309
+ self.test_num_slots = test_num_slots
310
+ if self.train_num_slots is not None:
311
+ self.train_num_slots = min(self.train_num_slots, self.num_slots)
312
+ else:
313
+ self.train_num_slots = self.num_slots
314
+ if self.test_num_slots is not None:
315
+ self.num_slots_to_gen = min(self.test_num_slots, self.train_num_slots)
316
+ else:
317
+ self.num_slots_to_gen = self.train_num_slots
318
+ self.eval_fid = eval_fid
319
+ if eval_fid:
320
+ assert fid_stats is not None
321
+ self.fid_stats = fid_stats
322
+
323
+ self.result_folder = result_folder
324
+ self.model_saved_dir = os.path.join(result_folder, "models")
325
+ os.makedirs(self.model_saved_dir, exist_ok=True)
326
+
327
+ self.image_saved_dir = os.path.join(result_folder, "images")
328
+ os.makedirs(self.image_saved_dir, exist_ok=True)
329
+
330
+ self.cache_dir = Path(cache_dir)
331
+ self.enable_cache_latents = enable_cache_latents
332
+ self.seed = seed
333
+ self.cache_loader = None
334
+
335
+ @property
336
+ def device(self):
337
+ return self.accelerator.device
338
+
339
+ def _creat_optimizer(self, weight_decay, learning_rate, betas):
340
+ # start with all of the candidate parameters
341
+ param_dict = {pn: p for pn, p in self.gpt_model.named_parameters()}
342
+ # filter out those that do not require grad
343
+ param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
344
+ # create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
345
+ # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
346
+ decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
347
+ nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
348
+ optim_groups = [
349
+ {'params': decay_params, 'weight_decay': weight_decay},
350
+ {'params': nodecay_params, 'weight_decay': 0.0}
351
+ ]
352
+ num_decay_params = sum(p.numel() for p in decay_params)
353
+ num_nodecay_params = sum(p.numel() for p in nodecay_params)
354
+ if self.accelerator.is_main_process:
355
+ print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
356
+ print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
357
+ optimizer = AdamW(optim_groups, lr=learning_rate, betas=betas)
358
+ return optimizer
359
+
360
+ def _create_scheduler(self, cosine_lr, warmup_epochs, warmup_steps, num_epoch, lr_min, warmup_lr_init, decay_steps):
361
+ if warmup_epochs is not None:
362
+ warmup_steps = warmup_epochs * len(self.train_dl)
363
+ else:
364
+ assert warmup_steps is not None
365
+
366
+ scheduler = build_scheduler(
367
+ self.g_optim,
368
+ num_epoch,
369
+ len(self.train_dl),
370
+ lr_min,
371
+ warmup_steps,
372
+ warmup_lr_init,
373
+ decay_steps,
374
+ cosine_lr, # if not cosine_lr, then use step_lr (warmup, then fix)
375
+ )
376
+ return scheduler
377
+
378
+ def _load_state_dict(self, state_dict, model):
379
+ """Helper to load a state dict with proper prefix handling."""
380
+ if 'state_dict' in state_dict:
381
+ state_dict = state_dict['state_dict']
382
+ # Remove '_orig_mod' prefix if present
383
+ state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
384
+ missing, unexpected = model.load_state_dict(
385
+ state_dict, strict=False
386
+ )
387
+ if self.accelerator.is_main_process:
388
+ print(f"Loaded model. Missing: {missing}, Unexpected: {unexpected}")
389
+
390
+ def _load_safetensors(self, path, model):
391
+ """Helper to load a safetensors checkpoint."""
392
+ from safetensors.torch import safe_open
393
+ with safe_open(path, framework="pt", device="cpu") as f:
394
+ state_dict = {k: f.get_tensor(k) for k in f.keys()}
395
+ self._load_state_dict(state_dict, model)
396
+
397
+ def _load_checkpoint(self, ckpt_path=None, model=None):
398
+ if ckpt_path is None or not osp.exists(ckpt_path):
399
+ return
400
+
401
+ if model is None:
402
+ model = self.accelerator.unwrap_model(self.gpt_model)
403
+
404
+ if osp.isdir(ckpt_path):
405
+ # ckpt_path is something like 'path/to/models/step10/'
406
+ self.loaded_steps = int(
407
+ ckpt_path.split("step")[-1].split("/")[0]
408
+ )
409
+ if not self.test_only:
410
+ self.accelerator.load_state(ckpt_path)
411
+ else:
412
+ if self.enable_ema:
413
+ model_path = osp.join(ckpt_path, "custom_checkpoint_1.pkl")
414
+ if osp.exists(model_path):
415
+ state_dict = torch.load(model_path, map_location="cpu")
416
+ self._load_state_dict(state_dict, model)
417
+ if self.accelerator.is_main_process:
418
+ print(f"Loaded ema model from {model_path}")
419
+ else:
420
+ model_path = osp.join(ckpt_path, "model.safetensors")
421
+ if osp.exists(model_path):
422
+ self._load_safetensors(model_path, model)
423
+ else:
424
+ # ckpt_path is something like 'path/to/models/step10.pt'
425
+ if ckpt_path.endswith(".safetensors"):
426
+ self._load_safetensors(ckpt_path, model)
427
+ else:
428
+ state_dict = torch.load(ckpt_path, map_location="cpu")
429
+ self._load_state_dict(state_dict, model)
430
+
431
+ if self.accelerator.is_main_process:
432
+ print(f"Loaded checkpoint from {ckpt_path}")
433
+
434
+ def _build_cache(self):
435
+ """Build cache for slots and targets."""
436
+ rank = self.accelerator.process_index
437
+ world_size = self.accelerator.num_processes
438
+
439
+ # Clean up any existing cache files first
440
+ slots_file = self.cache_dir / f"slots_rank{rank}_of_{world_size}.mmap"
441
+ targets_file = self.cache_dir / f"targets_rank{rank}_of_{world_size}.mmap"
442
+
443
+ if slots_file.exists():
444
+ os.remove(slots_file)
445
+ if targets_file.exists():
446
+ os.remove(targets_file)
447
+
448
+ dataset_size = len(self.train_dl.dataset)
449
+ shard_size = dataset_size // world_size
450
+
451
+ # Detect number of augmentations from first batch
452
+ with torch.no_grad():
453
+ sample_batch = next(iter(self.train_dl))
454
+ img, _ = sample_batch
455
+ num_augs = img.shape[1] if len(img.shape) == 5 else 1
456
+
457
+ print(f"Rank {rank}: Creating new cache with {num_augs} augmentations per image...")
458
+ os.makedirs(self.cache_dir, exist_ok=True)
459
+ slots_file = self.cache_dir / f"slots_rank{rank}_of_{world_size}.mmap"
460
+ targets_file = self.cache_dir / f"targets_rank{rank}_of_{world_size}.mmap"
461
+
462
+ # Create memory-mapped files
463
+ slots_mmap = np.memmap(
464
+ slots_file,
465
+ dtype='float32',
466
+ mode='w+',
467
+ shape=(shard_size * num_augs, self.train_num_slots, self.slot_dim)
468
+ )
469
+
470
+ targets_mmap = np.memmap(
471
+ targets_file,
472
+ dtype='int64',
473
+ mode='w+',
474
+ shape=(shard_size * num_augs,)
475
+ )
476
+
477
+ # Cache data
478
+ with torch.no_grad():
479
+ for i, batch in enumerate(tqdm(
480
+ self.train_dl,
481
+ desc=f"Rank {rank}: Caching data",
482
+ disable=not self.accelerator.is_local_main_process
483
+ )):
484
+ imgs, targets = batch
485
+ if len(imgs.shape) == 5: # [B, num_augs, C, H, W]
486
+ B, A, C, H, W = imgs.shape
487
+ imgs = imgs.view(-1, C, H, W) # [B*num_augs, C, H, W]
488
+ targets = targets.unsqueeze(1).expand(-1, A).reshape(-1) # [B*num_augs]
489
+
490
+ # Split imgs into n chunks
491
+ num_splits = num_augs
492
+ split_size = imgs.shape[0] // num_splits
493
+ imgs_splits = torch.split(imgs, split_size)
494
+ targets_splits = torch.split(targets, split_size)
495
+
496
+ start_idx = i * self.train_dl.batch_size * num_augs
497
+
498
+ for split_idx, (img_split, targets_split) in enumerate(zip(imgs_splits, targets_splits)):
499
+ img_split = img_split.to(self.device, non_blocking=True)
500
+ slots_split = self.ae_model.encode_slots(img_split)[:, :self.train_num_slots, :]
501
+
502
+ split_start = start_idx + (split_idx * split_size)
503
+ split_end = split_start + img_split.shape[0]
504
+
505
+ # Write directly to mmap files
506
+ slots_mmap[split_start:split_end] = slots_split.cpu().numpy()
507
+ targets_mmap[split_start:split_end] = targets_split.numpy()
508
+
509
+ # Close the mmap files
510
+ del slots_mmap
511
+ del targets_mmap
512
+
513
+ # Reopen in read mode
514
+ self.cached_latents = np.memmap(
515
+ slots_file,
516
+ dtype='float32',
517
+ mode='r',
518
+ shape=(shard_size * num_augs, self.train_num_slots, self.slot_dim)
519
+ )
520
+
521
+ self.cached_targets = np.memmap(
522
+ targets_file,
523
+ dtype='int64',
524
+ mode='r',
525
+ shape=(shard_size * num_augs,)
526
+ )
527
+
528
+ # Store the number of augmentations for the cache loader
529
+ self.num_augs = num_augs
530
+
531
+ def _setup_cache(self):
532
+ """Setup cache if enabled."""
533
+ self._build_cache()
534
+ self.accelerator.wait_for_everyone()
535
+
536
+ # Initialize cache loader if cache exists
537
+ if self.cached_latents is not None:
538
+ self.cache_loader = CacheDataLoader(
539
+ slots=self.cached_latents,
540
+ targets=self.cached_targets,
541
+ batch_size=self.batch_size,
542
+ num_augs=self.num_augs,
543
+ seed=self.seed + self.accelerator.process_index
544
+ )
545
+
546
+ def __del__(self):
547
+ """Cleanup cache files."""
548
+ if self.enable_cache_latents:
549
+ rank = self.accelerator.process_index
550
+ world_size = self.accelerator.num_processes
551
+
552
+ # Clean up slots cache
553
+ slots_file = self.cache_dir / f"slots_rank{rank}_of_{world_size}.mmap"
554
+ if slots_file.exists():
555
+ os.remove(slots_file)
556
+
557
+ # Clean up targets cache
558
+ targets_file = self.cache_dir / f"targets_rank{rank}_of_{world_size}.mmap"
559
+ if targets_file.exists():
560
+ os.remove(targets_file)
561
+
562
+ def _train_step(self, slots, targets=None):
563
+ """Execute single training step."""
564
+
565
+ with self.accelerator.accumulate(self.gpt_model):
566
+ with self.accelerator.autocast():
567
+ loss = self.gpt_model(slots, targets)
568
+
569
+ self.accelerator.backward(loss)
570
+ if self.accelerator.sync_gradients and self.max_grad_norm is not None:
571
+ self.accelerator.clip_grad_norm_(self.gpt_model.parameters(), self.max_grad_norm)
572
+ self.g_optim.step()
573
+ if self.g_sched is not None:
574
+ self.g_sched.step_update(self.steps)
575
+ self.g_optim.zero_grad()
576
+
577
+ # Update EMA model if enabled
578
+ if self.enable_ema:
579
+ self.ema_model.update(self.accelerator.unwrap_model(self.gpt_model))
580
+
581
+ return loss
582
+
583
+ def _train_epoch_cached(self, epoch, logger):
584
+ """Train one epoch using cached data."""
585
+ self.cache_loader.set_epoch(epoch)
586
+ header = f'Epoch: [{epoch}/{self.num_epoch}]'
587
+
588
+ for batch in logger.log_every(self.cache_loader, 20, header):
589
+ slots, targets = (b.to(self.device, non_blocking=True) for b in batch)
590
+
591
+ self.steps += 1
592
+
593
+ if self.steps == 1:
594
+ print(f"Training batch size: {len(slots)}")
595
+ print(f"Hello from index {self.accelerator.local_process_index}")
596
+
597
+ loss = self._train_step(slots, targets)
598
+ self._handle_periodic_ops(loss, logger)
599
+
600
+ def _train_epoch_uncached(self, epoch, logger):
601
+ """Train one epoch using raw data."""
602
+ header = f'Epoch: [{epoch}/{self.num_epoch}]'
603
+
604
+ for batch in logger.log_every(self.train_dl, 20, header):
605
+ img, targets = (b.to(self.device, non_blocking=True) for b in batch)
606
+
607
+ self.steps += 1
608
+
609
+ if self.steps == 1:
610
+ print(f"Training batch size: {img.size(0)}")
611
+ print(f"Hello from index {self.accelerator.local_process_index}")
612
+
613
+ slots = self.ae_model.encode_slots(img)[:, :self.train_num_slots, :]
614
+ loss = self._train_step(slots, targets)
615
+ self._handle_periodic_ops(loss, logger)
616
+
617
+ def _handle_periodic_ops(self, loss, logger):
618
+ """Handle periodic operations and logging."""
619
+ logger.update(loss=loss.item())
620
+ logger.update(lr=self.g_optim.param_groups[0]["lr"])
621
+
622
+ if self.steps % self.save_every == 0:
623
+ self.save()
624
+
625
+ if (self.steps % self.samp_every == 0) or (self.eval_fid and self.steps % self.fid_every == 0):
626
+ empty_cache()
627
+ self.evaluate()
628
+ self.accelerator.wait_for_everyone()
629
+ empty_cache()
630
+
631
+ def _save_config(self, config):
632
+ """Save configuration file."""
633
+ if config is not None and self.accelerator.is_main_process:
634
+ import shutil
635
+ from omegaconf import OmegaConf
636
+
637
+ if isinstance(config, str) and osp.exists(config):
638
+ shutil.copy(config, osp.join(self.result_folder, "config.yaml"))
639
+ else:
640
+ config_save_path = osp.join(self.result_folder, "config.yaml")
641
+ OmegaConf.save(config, config_save_path)
642
+
643
+ def _should_skip_epoch(self, epoch):
644
+ """Check if epoch should be skipped due to loaded checkpoint."""
645
+ loader = self.train_dl if not self.enable_cache_latents else self.cache_loader
646
+ if ((epoch + 1) * len(loader)) <= self.loaded_steps:
647
+ if self.accelerator.is_main_process:
648
+ print(f"Epoch {epoch} is skipped because it is loaded from ckpt")
649
+ self.steps += len(loader)
650
+ return True
651
+
652
+ if self.steps < self.loaded_steps:
653
+ for _ in loader:
654
+ self.steps += 1
655
+ if self.steps >= self.loaded_steps:
656
+ break
657
+ return False
658
+
659
+ def train(self, config=None):
660
+ """Main training loop."""
661
+ # Initial setup
662
+ n_parameters = sum(p.numel() for p in self.parameters() if p.requires_grad)
663
+ if self.accelerator.is_main_process:
664
+ print(f"number of learnable parameters: {n_parameters//1e6}M")
665
+
666
+ self._save_config(config)
667
+ self.accelerator.init_trackers("gpt")
668
+
669
+ # Handle test-only mode
670
+ if self.test_only:
671
+ empty_cache()
672
+ self.evaluate()
673
+ self.accelerator.wait_for_everyone()
674
+ empty_cache()
675
+ return
676
+
677
+ # Setup cache if enabled
678
+ if self.enable_cache_latents:
679
+ self._setup_cache()
680
+
681
+ # Training loop
682
+ for epoch in range(self.num_epoch):
683
+ if self._should_skip_epoch(epoch):
684
+ continue
685
+
686
+ self.gpt_model.train()
687
+ logger = MetricLogger(delimiter=" ")
688
+ logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
689
+
690
+ # Choose training path based on cache availability
691
+ if self.enable_cache_latents:
692
+ self._train_epoch_cached(epoch, logger)
693
+ else:
694
+ self._train_epoch_uncached(epoch, logger)
695
+
696
+ # Synchronize and log epoch stats
697
+ # logger.synchronize_between_processes()
698
+ # if self.accelerator.is_main_process:
699
+ # print("Averaged stats:", logger)
700
+
701
+ # Finish training
702
+ self.accelerator.end_training()
703
+ self.save()
704
+ if self.accelerator.is_main_process:
705
+ print("Train finished!")
706
+
707
+ def save(self):
708
+ self.accelerator.wait_for_everyone()
709
+ self.accelerator.save_state(
710
+ os.path.join(self.model_saved_dir, f"step{self.steps}")
711
+ )
712
+
713
+ @torch.no_grad()
714
+ def evaluate(self, use_ema=True):
715
+ self.gpt_model.eval()
716
+ unwraped_gpt_model = self.accelerator.unwrap_model(self.gpt_model)
717
+ # switch to ema params, only when eval_fid is True
718
+ use_ema = use_ema and self.enable_ema and self.eval_fid and not self.test_only
719
+ if use_ema:
720
+ if hasattr(self, "ema_model"):
721
+ model_without_ddp = self.accelerator.unwrap_model(self.gpt_model)
722
+ model_state_dict = copy.deepcopy(model_without_ddp.state_dict())
723
+ ema_state_dict = copy.deepcopy(model_without_ddp.state_dict())
724
+ for i, (name, _value) in enumerate(model_without_ddp.named_parameters()):
725
+ if "nested_sampler" in name:
726
+ continue
727
+ ema_state_dict[name] = self.ema_model.state_dict()[name]
728
+ if self.accelerator.is_main_process:
729
+ print("Switch to ema")
730
+ model_without_ddp.load_state_dict(ema_state_dict)
731
+ else:
732
+ print("EMA model not found, using original model")
733
+ use_ema = False
734
+
735
+ generate_fn = generate if 'GPT' in self.model_name else generate_causal_dit
736
+ if not self.test_only:
737
+ classes = torch.tensor(self.eval_classes, device=self.device)
738
+ with self.accelerator.autocast():
739
+ slots = generate_fn(unwraped_gpt_model, classes, self.num_slots_to_gen, cfg_scale=self.cfg, diff_cfg=self.diff_cfg, cfg_schedule=self.cfg_schedule, diff_cfg_schedule=self.diff_cfg_schedule, temperature=self.temperature)
740
+ if self.num_slots_to_gen < self.num_slots:
741
+ null_slots = self.ae_model.dit.null_cond.expand(slots.shape[0], -1, -1)
742
+ null_slots = null_slots[:, self.num_slots_to_gen:, :]
743
+ slots = torch.cat([slots, null_slots], dim=1)
744
+ imgs = self.ae_model.sample(slots, targets=classes, cfg=self.ae_cfg) # targets are not used for now
745
+
746
+ imgs = concat_all_gather(imgs)
747
+ if self.accelerator.num_processes > 16:
748
+ imgs = imgs[:16*len(self.eval_classes)]
749
+ imgs = imgs.detach().cpu()
750
+ grid = make_grid(
751
+ imgs, nrow=len(self.eval_classes), normalize=True, value_range=(0, 1)
752
+ )
753
+ if self.accelerator.is_main_process:
754
+ save_image(
755
+ grid,
756
+ os.path.join(
757
+ self.image_saved_dir, f"step{self.steps}_aecfg-{self.ae_cfg}_cfg-{self.cfg_schedule}-{self.cfg}_diffcfg-{self.diff_cfg_schedule}-{self.diff_cfg}_slots{self.num_slots_to_gen}_temp{self.temperature}.jpg"
758
+ ),
759
+ )
760
+ if self.eval_fid and (self.test_only or (self.steps % self.fid_every == 0)):
761
+ # Create output directory (only on main process)
762
+ save_folder = os.path.join(self.image_saved_dir, f"gen_step{self.steps}_aecfg-{self.ae_cfg}_cfg-{self.cfg_schedule}-{self.cfg}_diffcfg-{self.diff_cfg_schedule}-{self.diff_cfg}_slots{self.num_slots_to_gen}_temp{self.temperature}")
763
+ if self.accelerator.is_main_process:
764
+ os.makedirs(save_folder, exist_ok=True)
765
+
766
+ # Setup for distributed generation
767
+ world_size = self.accelerator.num_processes
768
+ local_rank = self.accelerator.process_index
769
+ batch_size = self.test_bs
770
+
771
+ # Create balanced class distribution
772
+ num_classes = self.num_classes
773
+ images_per_class = self.num_test_images // num_classes
774
+ class_labels = np.repeat(np.arange(num_classes), images_per_class)
775
+
776
+ # Shuffle the class labels to ensure random ordering
777
+ np.random.shuffle(class_labels)
778
+
779
+ total_images = len(class_labels)
780
+
781
+ padding_size = world_size * batch_size - (total_images % (world_size * batch_size))
782
+ class_labels = np.pad(class_labels, (0, padding_size), 'constant')
783
+ padded_total_images = len(class_labels)
784
+
785
+ # Distribute workload across GPUs
786
+ images_per_gpu = padded_total_images // world_size
787
+ start_idx = local_rank * images_per_gpu
788
+ end_idx = min(start_idx + images_per_gpu, padded_total_images)
789
+ local_class_labels = class_labels[start_idx:end_idx]
790
+ local_num_steps = len(local_class_labels) // batch_size
791
+
792
+ if self.accelerator.is_main_process:
793
+ print(f"Generating {total_images} images ({images_per_class} per class) across {world_size} GPUs")
794
+
795
+ used_time = 0
796
+ gen_img_cnt = 0
797
+
798
+ for i in range(local_num_steps):
799
+ if self.accelerator.is_main_process and i % 10 == 0:
800
+ print(f"Generation step {i}/{local_num_steps}")
801
+
802
+ # Get and pad labels for current batch
803
+ batch_start = i * batch_size
804
+ batch_end = batch_start + batch_size
805
+ labels = local_class_labels[batch_start:batch_end]
806
+
807
+ # Convert labels to a tensor on the current device; padded samples are trimmed after gathering
808
+ labels = torch.tensor(labels, device=self.device)
809
+
810
+ # Generate images
811
+ self.accelerator.wait_for_everyone()
812
+ start_time = time.time()
813
+ with torch.no_grad():
814
+ with self.accelerator.autocast():
815
+ slots = generate_fn(unwraped_gpt_model, labels, self.num_slots_to_gen,
816
+ cfg_scale=self.cfg, diff_cfg=self.diff_cfg,
817
+ cfg_schedule=self.cfg_schedule, diff_cfg_schedule=self.diff_cfg_schedule,
818
+ temperature=self.temperature)
819
+ if self.num_slots_to_gen < self.num_slots:
820
+ null_slots = self.ae_model.dit.null_cond.expand(slots.shape[0], -1, -1)
821
+ null_slots = null_slots[:, self.num_slots_to_gen:, :]
822
+ slots = torch.cat([slots, null_slots], dim=1)
823
+ imgs = self.ae_model.sample(slots, targets=labels, cfg=self.ae_cfg)
824
+
825
+ samples_in_batch = min(batch_size * world_size, total_images - gen_img_cnt)
826
+
827
+ # Update timing stats
828
+ used_time += time.time() - start_time
829
+ gen_img_cnt += samples_in_batch
830
+ if self.accelerator.is_main_process and i % 10 == 0:
831
+ print(f"Avg generation time: {used_time/gen_img_cnt:.5f} sec/image")
832
+
833
+ gathered_imgs = concat_all_gather(imgs)
834
+ gathered_imgs = gathered_imgs[:samples_in_batch]
835
+
836
+ # Save images (only on main process)
837
+ if self.accelerator.is_main_process:
838
+ real_imgs = gathered_imgs.detach().cpu()
839
+
840
+ save_paths = [
841
+ os.path.join(save_folder, f"{str(idx).zfill(5)}.png")
842
+ for idx in range(gen_img_cnt - samples_in_batch, gen_img_cnt)
843
+ ]
844
+ save_img_batch(real_imgs, save_paths)
845
+
846
+ # Calculate metrics (only on main process)
847
+ self.accelerator.wait_for_everyone()
848
+ if self.accelerator.is_main_process:
849
+ generated_files = len(os.listdir(save_folder))
850
+ print(f"Generated {generated_files} images out of {total_images} expected")
851
+
852
+ metrics_dict = get_fid_stats(save_folder, None, self.fid_stats)
853
+ fid = metrics_dict["frechet_inception_distance"]
854
+ inception_score = metrics_dict["inception_score_mean"]
855
+
856
+ metric_prefix = "fid_ema" if use_ema else "fid"
857
+ isc_prefix = "isc_ema" if use_ema else "isc"
858
+
859
+ self.accelerator.log({
860
+ metric_prefix: fid,
861
+ isc_prefix: inception_score,
862
+ "gpt_cfg": self.cfg,
863
+ "ae_cfg": self.ae_cfg,
864
+ "diff_cfg": self.diff_cfg,
865
+ "cfg_schedule": self.cfg_schedule,
866
+ "diff_cfg_schedule": self.diff_cfg_schedule,
867
+ "temperature": self.temperature,
868
+ "num_slots": self.test_num_slots if self.test_num_slots is not None else self.train_num_slots
869
+ }, step=self.steps)
870
+
871
+ # Print comprehensive CFG information
872
+ cfg_info = (
873
+ f"{'EMA ' if use_ema else ''}CFG params: "
874
+ f"gpt_cfg={self.cfg}, ae_cfg={self.ae_cfg}, diff_cfg={self.diff_cfg}, "
875
+ f"cfg_schedule={self.cfg_schedule}, diff_cfg_schedule={self.diff_cfg_schedule}, "
876
+ f"num_slots={self.test_num_slots if self.test_num_slots is not None else self.train_num_slots}, "
877
+ f"temperature={self.temperature}"
878
+ )
879
+ print(cfg_info)
880
+ print(f"FID: {fid:.2f}, ISC: {inception_score:.2f}")
881
+
882
+ # Cleanup
883
+ shutil.rmtree(save_folder)
884
+
885
+ # back to no ema
886
+ if use_ema:
887
+ if self.accelerator.is_main_process:
888
+ print("Switch back from ema")
889
+ model_without_ddp.load_state_dict(model_state_dict)
890
+
891
+ self.gpt_model.train()
892
+
paintmind/engine/misc.py ADDED
@@ -0,0 +1,260 @@
1
+ import socket
2
+ import os, sys, pdb
3
+ from torch import inf
4
+ import os.path as osp
5
+ from pathlib import Path
6
+ import builtins, datetime
7
+ import torch.distributed as dist
8
+ import os, sys, time, torch, copy, pdb
9
+ from collections import defaultdict, deque
10
+
11
+ def print_available_port():
12
+
13
+ return _find_free_port()
14
+
15
+ def _find_free_port():
16
+
17
+ sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
18
+ # Binding to port 0 will cause the OS to find an available port for us
19
+ sock.bind(("", 0))
20
+ port = sock.getsockname()[1]
21
+ sock.close()
22
+ # NOTE: there is still a chance the port could be taken by other processes.
23
+ return port
24
+
25
+ def ensure_dir(dirpath):
26
+
27
+ if not osp.exists(dirpath):
28
+ os.makedirs(dirpath, exist_ok=True)
29
+
30
+ def setup_for_distributed(is_master):
31
+ """
32
+ This function disables printing when not in master process
33
+ """
34
+ builtin_print = builtins.print
35
+
36
+ def print(*args, **kwargs):
37
+ force = kwargs.pop('force', False)
38
+ force = force or (get_world_size() > 8)
39
+ if is_master or force:
40
+ now = datetime.datetime.now().time()
41
+ builtin_print('[{}] '.format(now), end='') # print with time stamp
42
+ builtin_print(*args, **kwargs)
43
+
44
+ builtins.print = print
45
+
46
+
47
+ def is_dist_avail_and_initialized():
48
+ if not dist.is_available():
49
+ return False
50
+ if not dist.is_initialized():
51
+ return False
52
+ return True
53
+
54
+
55
+ def get_world_size():
56
+ if not is_dist_avail_and_initialized():
57
+ return 1
58
+ return dist.get_world_size()
59
+
60
+
61
+ def get_rank():
62
+ if not is_dist_avail_and_initialized():
63
+ return 0
64
+ return dist.get_rank()
65
+
66
+
67
+ def concat_all_gather(tensor):
68
+ """
69
+ Performs all_gather operation on the provided tensors.
70
+ *** Warning ***: torch.distributed.all_gather has no gradient.
71
+ """
72
+ tensors_gather = [torch.ones_like(tensor)
73
+ for _ in range(torch.distributed.get_world_size())]
74
+ torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
75
+
76
+ output = torch.cat(tensors_gather, dim=0)
77
+ return output
78
+
79
+
80
+ def is_main_process():
81
+ return get_rank() == 0
82
+
83
+
84
+ def save_on_master(*args, **kwargs):
85
+ if is_main_process():
86
+ torch.save(*args, **kwargs)
87
+
88
+
89
+ def init_distributed_mode(args):
90
+
91
+ if args.dist_on_itp:
92
+ args.rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
93
+ args.world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
94
+ args.gpu = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
95
+ assert isinstance(args.port, int) & (args.port > 0) & (args.port < 1<<30)
96
+ port = _find_free_port()
97
+ # args.dist_url = "tcp://%s:%s" % (port, os.environ['MASTER_PORT'])
98
+ args.dist_url = f'tcp://127.0.0.1:{port}'
99
+ os.environ['LOCAL_RANK'] = str(args.gpu)
100
+ os.environ['RANK'] = str(args.rank)
101
+ os.environ['WORLD_SIZE'] = str(args.world_size)
102
+ # ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT", "LOCAL_RANK"]
103
+ elif 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
104
+ args.rank = int(os.environ["RANK"])
105
+ args.world_size = int(os.environ['WORLD_SIZE'])
106
+ args.gpu = int(os.environ['LOCAL_RANK'])
107
+ elif 'SLURM_PROCID' in os.environ:
108
+ args.rank = int(os.environ['SLURM_PROCID'])
109
+ args.gpu = args.rank % torch.cuda.device_count()
110
+ else:
111
+ print('Not using distributed mode')
112
+ setup_for_distributed(is_master=True) # hack
113
+ args.distributed = False
114
+ return
115
+
116
+ args.distributed = True
117
+
118
+ torch.cuda.set_device(args.gpu)
119
+ args.dist_backend = 'nccl'
120
+ print('| distributed init (rank {}): {}, gpu {}'.format(
121
+ args.rank, args.dist_url, args.gpu), flush=True)
122
+ torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
123
+ world_size=args.world_size, rank=args.rank)
124
+ torch.distributed.barrier()
125
+ setup_for_distributed(args.rank == 0)
126
+
127
+ class NativeScalerWithGradNormCount:
128
+ state_dict_key = "amp_scaler"
129
+
130
+ def __init__(self):
131
+ self._scaler = torch.cuda.amp.GradScaler()
132
+
133
+ def __call__(self, loss, optimizer, clip_grad=None, parameters=None, create_graph=False, update_grad=True):
134
+ self._scaler.scale(loss).backward(create_graph=create_graph)
135
+ if update_grad:
136
+ if clip_grad is not None:
137
+ assert parameters is not None
138
+ self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
139
+ norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
140
+ else:
141
+ self._scaler.unscale_(optimizer)
142
+ norm = get_grad_norm_(parameters)
143
+ self._scaler.step(optimizer)
144
+ self._scaler.update()
145
+ else:
146
+ norm = None
147
+ return norm
148
+
149
+ def state_dict(self):
150
+ return self._scaler.state_dict()
151
+
152
+ def load_state_dict(self, state_dict):
153
+ self._scaler.load_state_dict(state_dict)
154
+
155
+
156
+ def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
157
+ if isinstance(parameters, torch.Tensor):
158
+ parameters = [parameters]
159
+ parameters = [p for p in parameters if p.grad is not None]
160
+ norm_type = float(norm_type)
161
+ if len(parameters) == 0:
162
+ return torch.tensor(0.)
163
+ device = parameters[0].grad.device
164
+ if norm_type == inf:
165
+ total_norm = max(p.grad.detach().abs().max().to(device) for p in parameters)
166
+ else:
167
+ total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]), norm_type)
168
+ return total_norm
169
+
170
+
171
+ def save_model(args, epoch, model, model_without_ddp, optimizer, loss_scaler, ema_params=None):
172
+ output_dir = Path(args.output_dir)
173
+ epoch_name = str(epoch)
174
+ if loss_scaler is not None:
175
+ checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
176
+
177
+ # ema
178
+ if ema_params is not None:
179
+ ema_state_dict = copy.deepcopy(model_without_ddp.state_dict())
180
+ for i, (name, _value) in enumerate(model_without_ddp.named_parameters()):
181
+ assert name in ema_state_dict
182
+ ema_state_dict[name] = ema_params[i]
183
+ else:
184
+ ema_state_dict = None
185
+
186
+ for checkpoint_path in checkpoint_paths:
187
+ to_save = {
188
+ 'model': model_without_ddp.state_dict(),
189
+ 'model_ema': ema_state_dict,
190
+ 'optimizer': optimizer.state_dict(),
191
+ 'epoch': epoch,
192
+ 'scaler': loss_scaler.state_dict(),
193
+ 'args': args,
194
+ }
195
+
196
+ save_on_master(to_save, checkpoint_path)
197
+ else:
198
+ client_state = {'epoch': epoch}
199
+ model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)
200
+
201
+
202
+ def save_model_last(args, epoch, model, model_without_ddp, optimizer, loss_scaler, ema_params=None):
203
+
204
+ output_dir = Path(args.output_dir)
205
+ epoch_name = 'last'
206
+ if loss_scaler is not None:
207
+ checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
208
+
209
+ # ema
210
+ if ema_params is not None:
211
+ ema_state_dict = copy.deepcopy(model_without_ddp.state_dict())
212
+ for i, (name, _value) in enumerate(model_without_ddp.named_parameters()):
213
+ assert name in ema_state_dict
214
+ ema_state_dict[name] = ema_params[i]
215
+ else:
216
+ ema_state_dict = None
217
+
218
+ for checkpoint_path in checkpoint_paths:
219
+ to_save = {
220
+ 'model': model_without_ddp.state_dict(),
221
+ 'model_ema': ema_state_dict,
222
+ 'optimizer': optimizer.state_dict(),
223
+ 'epoch': epoch,
224
+ 'scaler': loss_scaler.state_dict(),
225
+ 'args': args,
226
+ }
227
+
228
+ save_on_master(to_save, checkpoint_path)
229
+ else:
230
+ client_state = {'epoch': epoch}
231
+ model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)
232
+
233
+
234
+ def load_model(args, model_without_ddp, optimizer, loss_scaler):
235
+
236
+ if osp.exists(osp.join(args.resume, "checkpoint-last.pth")):
237
+ resume_path = osp.join(args.resume, "checkpoint-last.pth")
238
+ else:
239
+ resume_path = args.resume
240
+ if args.resume:
241
+ checkpoint = torch.load(resume_path, map_location='cpu')
242
+ model_without_ddp.load_state_dict(checkpoint['model'])
243
+ print("Resume checkpoint %s" % resume_path)
244
+ if 'optimizer' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'evaluate') and args.evaluate):
245
+ optimizer.load_state_dict(checkpoint['optimizer'])
246
+ args.start_epoch = checkpoint['epoch'] + 1
247
+ if 'scaler' in checkpoint:
248
+ loss_scaler.load_state_dict(checkpoint['scaler'])
249
+ print("With optim & sched!")
250
+
251
+ def all_reduce_mean(x):
252
+
253
+ world_size = get_world_size()
254
+ if world_size > 1:
255
+ x_reduce = torch.tensor(x).cuda()
256
+ dist.all_reduce(x_reduce)
257
+ x_reduce /= world_size
258
+ return x_reduce.item()
259
+ else:
260
+ return x
paintmind/engine/trainer.py ADDED
@@ -0,0 +1,695 @@
1
+ import os, torch
2
+ import os.path as osp
3
+ import cv2
4
+ import shutil
5
+ import numpy as np
6
+ import copy
7
+ import torch_fidelity
8
+ import torch.nn as nn
9
+ from tqdm.auto import tqdm
10
+ from collections import OrderedDict
11
+ from einops import rearrange
12
+ from accelerate import Accelerator
13
+ from .util import instantiate_from_config
14
+ from torchvision.utils import make_grid, save_image
15
+ from torch.utils.data import DataLoader, random_split, DistributedSampler
16
+ from paintmind.utils.lr_scheduler import build_scheduler
17
+ from paintmind.utils.logger import SmoothedValue, MetricLogger, synchronize_processes, empty_cache
18
+ from paintmind.engine.misc import is_main_process, all_reduce_mean, concat_all_gather
19
+ from accelerate.utils import DistributedDataParallelKwargs, AutocastKwargs
20
+ from torch.optim import AdamW
21
+ from concurrent.futures import ThreadPoolExecutor
22
+ from torchmetrics.functional.image import (
23
+ peak_signal_noise_ratio as psnr,
24
+ structural_similarity_index_measure as ssim
25
+ )
26
+
27
+ def requires_grad(model, flag=True):
28
+ for p in model.parameters():
29
+ p.requires_grad = flag
30
+
31
+
32
+ def save_img(img, save_path):
33
+ img = np.clip(img.numpy().transpose([1, 2, 0]) * 255, 0, 255)
34
+ img = img.astype(np.uint8)[:, :, ::-1]
35
+ cv2.imwrite(save_path, img)
36
+
37
+ def save_img_batch(imgs, save_paths):
38
+ """Process and save multiple images at once using a thread pool."""
39
+ # Convert to numpy and prepare all images in one go
40
+ imgs = np.clip(imgs.numpy().transpose(0, 2, 3, 1) * 255, 0, 255).astype(np.uint8)
41
+ imgs = imgs[:, :, :, ::-1] # RGB to BGR for all images at once
42
+
43
+ # Use ThreadPoolExecutor here: saving files is I/O-bound, so threads are a better
44
+ # fit than a ProcessPoolExecutor (which suits CPU-bound work)
45
+ with ThreadPoolExecutor(max_workers=32) as pool:
46
+ # Submit all tasks at once
47
+ futures = [pool.submit(cv2.imwrite, path, img)
48
+ for path, img in zip(save_paths, imgs)]
49
+ # Wait for all tasks to complete
50
+ for future in futures:
51
+ future.result() # This will raise any exceptions that occurred
52
+
53
+ def get_fid_stats(real_dir, rec_dir, fid_stats):
54
+ stats = torch_fidelity.calculate_metrics(
55
+ input1=rec_dir,
56
+ input2=real_dir,
57
+ fid_statistics_file=fid_stats,
58
+ cuda=True,
59
+ isc=True,
60
+ fid=True,
61
+ kid=False,
62
+ prc=False,
63
+ verbose=False,
64
+ )
65
+ return stats
66
+
67
+
68
+ class EMAModel:
69
+ """Model Exponential Moving Average."""
70
+ def __init__(self, model, device, decay=0.999):
71
+ self.device = device
72
+ self.decay = decay
73
+ self.ema_params = OrderedDict(
74
+ (name, param.clone().detach().to(device))
75
+ for name, param in model.named_parameters()
76
+ if param.requires_grad
77
+ )
78
+
79
+ @torch.no_grad()
80
+ def update(self, model):
81
+ for name, param in model.named_parameters():
82
+ if param.requires_grad:
83
+ if name in self.ema_params:
84
+ self.ema_params[name].lerp_(param.data, 1 - self.decay)
85
+ else:
86
+ self.ema_params[name] = param.data.clone().detach()
87
+
88
+ def state_dict(self):
89
+ return self.ema_params
90
+
91
+ def load_state_dict(self, params):
92
+ self.ema_params = OrderedDict(
93
+ (name, param.clone().detach().to(self.device))
94
+ for name, param in params.items()
95
+ )
96
+
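A quick sanity check of the EMA update rule above: `lerp_(param, 1 - decay)` is the usual `ema = decay * ema + (1 - decay) * param`. Toy numbers, CPU only.

```python
import torch
import torch.nn as nn

decay = 0.999
model = nn.Linear(2, 2)
ema = {n: p.clone().detach() for n, p in model.named_parameters()}

# Pretend one optimizer step changed the weights.
with torch.no_grad():
    for p in model.parameters():
        p.add_(0.1)

for n, p in model.named_parameters():
    expected = decay * ema[n] + (1 - decay) * p.data
    ema[n].lerp_(p.data, 1 - decay)            # same update as EMAModel.update
    assert torch.allclose(ema[n], expected)
```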
97
+ class DiffusionTrainer(nn.Module):
98
+ def __init__(
99
+ self,
100
+ model,
101
+ dataset,
102
+ test_dataset=None,
103
+ test_only=False,
104
+ num_epoch=400,
105
+ valid_size=32,
106
+ lr=None,
107
+ blr=1e-4,
108
+ cosine_lr=True,
109
+ lr_min=0,
110
+ warmup_epochs=100,
111
+ warmup_steps=None,
112
+ warmup_lr_init=0,
113
+ decay_steps=None,
114
+ batch_size=32,
115
+ eval_bs=32,
116
+ test_bs=64,
117
+ num_workers=0,
118
+ pin_memory=False,
119
+ max_grad_norm=None,
120
+ grad_accum_steps=1,
121
+ precision="bf16",
122
+ save_every=10000,
123
+ sample_every=1000,
124
+ fid_every=50000,
125
+ result_folder=None,
126
+ log_dir="./log",
127
+ steps=0,
128
+ cfg=1.0,
129
+ test_num_slots=None,
130
+ eval_fid=False,
131
+ fid_stats=None,
132
+ enable_ema=False,
133
+ use_multi_epochs_dataloader=False,
134
+ compile=False,
135
+ overfit=False,
136
+ making_cache=False,
137
+ cache_mode=False,
138
+ latent_cache_file=None,
139
+ ):
140
+ super().__init__()
141
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
142
+ self.accelerator = Accelerator(
143
+ kwargs_handlers=[kwargs],
144
+ mixed_precision="bf16",
145
+ gradient_accumulation_steps=grad_accum_steps,
146
+ log_with="tensorboard",
147
+ project_dir=log_dir,
148
+ )
149
+
150
+ self.model = instantiate_from_config(model)
151
+ self.num_slots = model.params.num_slots
152
+
153
+ assert precision in ["bf16", "fp32"]
154
+ precision = "fp32"
155
+ if self.accelerator.is_main_process:
156
+ print("Overlooking specified precision and using autocast bf16...")
157
+ self.precision = precision
158
+
159
+ if test_dataset is not None:
160
+ test_dataset = instantiate_from_config(test_dataset)
161
+ self.test_ds = test_dataset
162
+
163
+ # Calculate padded dataset size to ensure even distribution
164
+ total_size = len(test_dataset)
165
+ world_size = self.accelerator.num_processes
166
+ padding_size = world_size * test_bs - (total_size % (world_size * test_bs))
167
+ self.test_dataset_size = total_size
168
+
169
+ # Create a padded dataset wrapper
170
+ class PaddedDataset(torch.utils.data.Dataset):
171
+ def __init__(self, dataset, padding_size):
172
+ self.dataset = dataset
173
+ self.padding_size = padding_size
174
+
175
+ def __len__(self):
176
+ return len(self.dataset) + self.padding_size
177
+
178
+ def __getitem__(self, idx):
179
+ if idx < len(self.dataset):
180
+ return self.dataset[idx]
181
+ return self.dataset[0]
182
+
183
+ self.test_ds = PaddedDataset(self.test_ds, padding_size)
184
+ self.test_dl = DataLoader(
185
+ self.test_ds,
186
+ batch_size=test_bs,
187
+ num_workers=num_workers,
188
+ pin_memory=pin_memory,
189
+ shuffle=False,
190
+ drop_last=True,
191
+ )
192
+ if self.accelerator.is_main_process:
193
+ print(f"test dataset size: {len(test_dataset)}, test batch size: {test_bs}")
194
+ else:
195
+ self.test_dl = None
196
+ self.test_only = test_only
197
+
198
+ if not test_only:
199
+ dataset = instantiate_from_config(dataset)
200
+ train_size = len(dataset) - valid_size
201
+ self.train_ds, self.valid_ds = random_split(
202
+ dataset,
203
+ [train_size, valid_size],
204
+ generator=torch.Generator().manual_seed(42),
205
+ )
206
+ if self.accelerator.is_main_process:
207
+ print(f"train dataset size: {train_size}, valid dataset size: {valid_size}")
208
+
209
+ sampler = DistributedSampler(
210
+ self.train_ds,
211
+ num_replicas=self.accelerator.num_processes,
212
+ rank=self.accelerator.process_index,
213
+ shuffle=True,
214
+ )
215
+ self.train_dl = DataLoader(
216
+ self.train_ds,
217
+ batch_size=batch_size,
218
+ sampler=sampler,
219
+ num_workers=num_workers,
220
+ pin_memory=pin_memory,
221
+ drop_last=True,
222
+ )
223
+ self.valid_dl = DataLoader(
224
+ self.valid_ds,
225
+ batch_size=eval_bs,
226
+ shuffle=False,
227
+ num_workers=num_workers,
228
+ pin_memory=pin_memory,
229
+ )
230
+
231
+ effective_bs = batch_size * grad_accum_steps * self.accelerator.num_processes
232
+ if lr is None:
233
+ lr = blr * effective_bs / 256
234
+ if self.accelerator.is_main_process:
235
+ print(f"Effective batch size is {effective_bs}")
236
+
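The base learning rate follows the usual linear scaling rule, lr = blr x effective_batch_size / 256. A worked example with a hypothetical setup:

```python
blr = 1e-4
batch_size, grad_accum_steps, num_processes = 32, 1, 8          # hypothetical run
effective_bs = batch_size * grad_accum_steps * num_processes    # 256
lr = blr * effective_bs / 256                                   # 1e-4, i.e. blr unchanged
```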
237
+ params = filter(lambda p: p.requires_grad, self.model.parameters())
238
+ self.g_optim = AdamW(params, lr=lr, betas=(0.9, 0.95), weight_decay=0)
239
+ self.g_sched = self._create_scheduler(
240
+ cosine_lr, warmup_epochs, warmup_steps, num_epoch,
241
+ lr_min, warmup_lr_init, decay_steps
242
+ )
243
+ if self.g_sched is not None:
244
+ self.accelerator.register_for_checkpointing(self.g_sched)
245
+
246
+ self.steps = steps
247
+ self.loaded_steps = -1
248
+
249
+ # Prepare everything together
250
+ if not test_only:
251
+ self.model, self.g_optim, self.g_sched = self.accelerator.prepare(
252
+ self.model, self.g_optim, self.g_sched
253
+ )
254
+ else:
255
+ self.model, self.test_dl = self.accelerator.prepare(self.model, self.test_dl)
256
+
257
+ if compile:
258
+ _model = self.accelerator.unwrap_model(self.model)
259
+ _model.vae = torch.compile(_model.vae, mode="reduce-overhead")
260
+ _model.dit = torch.compile(_model.dit, mode="reduce-overhead")
261
+ # _model.encoder = torch.compile(_model.encoder, mode="reduce-overhead") # nan loss when compiled together with dit, no idea why
262
+ _model.encoder2slot = torch.compile(_model.encoder2slot, mode="reduce-overhead")
263
+
264
+ self.enable_ema = enable_ema
265
+ if self.enable_ema and not self.test_only: # when testing, we directly load the ema dict and skip here
266
+ self.ema_model = EMAModel(self.accelerator.unwrap_model(self.model), self.device)
267
+ self.accelerator.register_for_checkpointing(self.ema_model)
268
+
269
+ self._load_checkpoint(model.params.ckpt_path)
270
+ if self.test_only:
271
+ self.steps = self.loaded_steps
272
+
273
+ self.num_epoch = num_epoch
274
+ self.save_every = save_every
275
+ self.samp_every = sample_every
276
+ self.fid_every = fid_every
277
+ self.max_grad_norm = max_grad_norm
278
+
279
+ self.cache_mode = cache_mode
280
+
281
+ self.cfg = cfg
282
+ self.test_num_slots = test_num_slots
283
+ if self.test_num_slots is not None:
284
+ self.test_num_slots = min(self.test_num_slots, self.num_slots)
285
+ else:
286
+ self.test_num_slots = self.num_slots
287
+ eval_fid = eval_fid or model.params.eval_fid # legacy
288
+ self.eval_fid = eval_fid
289
+ if eval_fid:
290
+ if fid_stats is None:
291
+ fid_stats = model.params.fid_stats # legacy
292
+ assert fid_stats is not None
293
+ assert test_dataset is not None
294
+ self.fid_stats = fid_stats
295
+
296
+ self.use_vq = model.params.use_vq if hasattr(model.params, "use_vq") else False
297
+ self.vq_beta = model.params.code_beta if hasattr(model.params, "code_beta") else 0.25
298
+
299
+ self.result_folder = result_folder
300
+ self.model_saved_dir = os.path.join(result_folder, "models")
301
+ os.makedirs(self.model_saved_dir, exist_ok=True)
302
+
303
+ self.image_saved_dir = os.path.join(result_folder, "images")
304
+ os.makedirs(self.image_saved_dir, exist_ok=True)
305
+
306
+ @property
307
+ def device(self):
308
+ return self.accelerator.device
309
+
310
+ def _create_scheduler(self, cosine_lr, warmup_epochs, warmup_steps, num_epoch, lr_min, warmup_lr_init, decay_steps):
311
+ if warmup_epochs is not None:
312
+ warmup_steps = warmup_epochs * len(self.train_dl)
313
+ else:
314
+ assert warmup_steps is not None
315
+
316
+ scheduler = build_scheduler(
317
+ self.g_optim,
318
+ num_epoch,
319
+ len(self.train_dl),
320
+ lr_min,
321
+ warmup_steps,
322
+ warmup_lr_init,
323
+ decay_steps,
324
+ cosine_lr, # if not cosine_lr, then use step_lr (warmup, then fix)
325
+ )
326
+ return scheduler
327
+
328
+ def _load_state_dict(self, state_dict):
329
+ """Helper to load a state dict with proper prefix handling."""
330
+ if 'state_dict' in state_dict:
331
+ state_dict = state_dict['state_dict']
332
+ # Remove '_orig_mod' prefix if present
333
+ state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
334
+ missing, unexpected = self.accelerator.unwrap_model(self.model).load_state_dict(
335
+ state_dict, strict=False
336
+ )
337
+ if self.accelerator.is_main_process:
338
+ print(f"Loaded model. Missing: {missing}, Unexpected: {unexpected}")
339
+
340
+ def _load_safetensors(self, path):
341
+ """Helper to load a safetensors checkpoint."""
342
+ from safetensors.torch import safe_open
343
+ with safe_open(path, framework="pt", device="cpu") as f:
344
+ state_dict = {k: f.get_tensor(k) for k in f.keys()}
345
+ self._load_state_dict(state_dict)
346
+
347
+ def _load_checkpoint(self, ckpt_path=None):
348
+ if ckpt_path is None or not osp.exists(ckpt_path):
349
+ return
350
+
351
+ if osp.isdir(ckpt_path):
352
+ # ckpt_path is something like 'path/to/models/step10/'
353
+ self.loaded_steps = int(
354
+ ckpt_path.split("step")[-1].split("/")[0]
355
+ )
356
+ if not self.test_only:
357
+ self.accelerator.load_state(ckpt_path)
358
+ else:
359
+ if self.enable_ema:
360
+ model_path = osp.join(ckpt_path, "custom_checkpoint_1.pkl")
361
+ if osp.exists(model_path):
362
+ state_dict = torch.load(model_path, map_location="cpu")
363
+ self._load_state_dict(state_dict)
364
+ if self.accelerator.is_main_process:
365
+ print(f"Loaded ema model from {model_path}")
366
+ else:
367
+ model_path = osp.join(ckpt_path, "model.safetensors")
368
+ if osp.exists(model_path):
369
+ self._load_safetensors(model_path)
370
+ else:
371
+ # ckpt_path is something like 'path/to/models/step10.pt'
372
+ if ckpt_path.endswith(".safetensors"):
373
+ self._load_safetensors(ckpt_path)
374
+ else:
375
+ state_dict = torch.load(ckpt_path)
376
+ self._load_state_dict(state_dict)
377
+ if self.accelerator.is_main_process:
378
+ print(f"Loaded checkpoint from {ckpt_path}")
379
+
380
+ def train(self, config=None):
381
+ n_parameters = sum(p.numel() for p in self.parameters() if p.requires_grad)
382
+ if self.accelerator.is_main_process:
383
+ print(f"number of learnable parameters: {n_parameters//1e6}M")
384
+ if config is not None:
385
+ # save the config
386
+ import shutil
387
+ from omegaconf import OmegaConf
388
+
389
+ if isinstance(config, str) and osp.exists(config):
390
+ # If it's a path, copy the file to config.yaml
391
+ shutil.copy(config, osp.join(self.result_folder, "config.yaml"))
392
+ else:
393
+ # If it's an OmegaConf object, dump it
394
+ config_save_path = osp.join(self.result_folder, "config.yaml")
395
+ OmegaConf.save(config, config_save_path)
396
+
397
+ self.accelerator.init_trackers("vqgan")
398
+
399
+ if self.test_only:
400
+ empty_cache()
401
+ self.evaluate()
402
+ self.accelerator.wait_for_everyone()
403
+ empty_cache()
404
+ return
405
+
406
+ for epoch in range(self.num_epoch):
407
+ if ((epoch + 1) * len(self.train_dl)) <= self.loaded_steps:
408
+ if self.accelerator.is_main_process:
409
+ print(f"Epoch {epoch} is skipped because it is loaded from ckpt")
410
+ self.steps += len(self.train_dl)
411
+ continue
412
+
413
+ if self.steps < self.loaded_steps:
414
+ for _ in self.train_dl:
415
+ self.steps += 1
416
+ if self.steps >= self.loaded_steps:
417
+ break
418
+
419
+
420
+ self.accelerator.unwrap_model(self.model).current_epoch = epoch
421
+ self.model.train() # Set model to training mode
422
+
423
+ logger = MetricLogger(delimiter=" ")
424
+ logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
425
+ header = 'Epoch: [{}/{}]'.format(epoch, self.num_epoch)
426
+ print_freq = 20
427
+ for data_iter_step, batch in enumerate(logger.log_every(self.train_dl, print_freq, header)):
428
+ # Move batch to device once
429
+ if isinstance(batch, tuple) or isinstance(batch, list):
430
+ batch = tuple(b.to(self.device, non_blocking=True) for b in batch)
431
+ if self.cache_mode:
432
+ img, latent, targets = batch[0], batch[1], batch[2]
433
+ img = img.to(self.device, non_blocking=True)
434
+ latent = latent.to(self.device, non_blocking=True)
435
+ targets = targets.to(self.device, non_blocking=True)
436
+ else:
437
+ latent = None
438
+ img, targets = batch[0], batch[1]
439
+ img = img.to(self.device, non_blocking=True)
440
+ targets = targets.to(self.device, non_blocking=True)
441
+ else:
442
+ img = batch
443
+ latent = None
444
+
445
+ self.steps += 1
446
+
447
+ with self.accelerator.accumulate(self.model):
448
+ with self.accelerator.autocast():
449
+ if self.steps == 1:
450
+ print(f"Training batch size: {img.size(0)}")
451
+ print(f"Hello from index {self.accelerator.local_process_index}")
452
+ losses = self.model(img, targets, latents=latent, epoch=epoch)
453
+ # combine
454
+ loss = sum([v for _, v in losses.items()])
455
+ diff_loss = losses["diff_loss"]
456
+ if self.use_vq:
457
+ vq_loss = losses["vq_loss"]
458
+
459
+ self.accelerator.backward(loss)
460
+ if self.accelerator.sync_gradients and self.max_grad_norm is not None:
461
+ self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
462
+ self.accelerator.unwrap_model(self.model).cancel_gradients_encoder(epoch)
463
+ self.g_optim.step()
464
+ if self.g_sched is not None:
465
+ self.g_sched.step_update(self.steps)
466
+ self.g_optim.zero_grad()
467
+
468
+ # synchronize_processes()
469
+
470
+ # update ema with state dict
471
+ if self.enable_ema:
472
+ self.ema_model.update(self.accelerator.unwrap_model(self.model))
473
+
474
+ logger.update(diff_loss=diff_loss.item())
475
+ if self.use_vq:
476
+ logger.update(vq_loss=vq_loss.item() / self.vq_beta)
477
+ if 'kl_loss' in losses:
478
+ logger.update(kl_loss=losses["kl_loss"].item())
479
+ if 'repa_loss' in losses:
480
+ logger.update(repa_loss=losses["repa_loss"].item())
481
+ logger.update(lr=self.g_optim.param_groups[0]["lr"])
482
+
483
+ if self.steps % self.save_every == 0:
484
+ self.save()
485
+
486
+ if (self.steps % self.samp_every == 0) or (self.steps % self.fid_every == 0):
487
+ empty_cache()
488
+ self.evaluate()
489
+ self.accelerator.wait_for_everyone()
490
+ empty_cache()
491
+
492
+ # omitted all_gather here
493
+ # write_dict = dict(epoch=epoch)
494
+ # write_dict.update(diff_loss=diff_loss.item())
495
+ # if "kl_loss" in losses:
496
+ # write_dict.update(kl_loss=losses["kl_loss"].item())
497
+ # if self.use_vq:
498
+ # write_dict.update(vq_loss=vq_loss.item() / self.vq_beta)
499
+ # write_dict.update(lr=self.g_optim.param_groups[0]["lr"])
500
+ # self.accelerator.log(write_dict, step=self.steps)
501
+
502
+ logger.synchronize_between_processes()
503
+ if self.accelerator.is_main_process:
504
+ print("Averaged stats:", logger)
505
+
506
+ self.accelerator.end_training()
507
+ self.save()
508
+ if self.accelerator.is_main_process:
509
+ print("Train finished!")
510
+
511
+ def save(self):
512
+ self.accelerator.wait_for_everyone()
513
+ self.accelerator.save_state(
514
+ os.path.join(self.model_saved_dir, f"step{self.steps}")
515
+ )
516
+
517
+ @torch.no_grad()
518
+ def evaluate(self, use_ema=True):
519
+ self.model.eval()
520
+ # switch to ema params, only when eval_fid is True
521
+ use_ema = use_ema and self.enable_ema and self.eval_fid and not self.test_only
522
+ # use_ema = False
523
+ if use_ema:
524
+ if hasattr(self, "ema_model"):
525
+ model_without_ddp = self.accelerator.unwrap_model(self.model)
526
+ model_state_dict = copy.deepcopy(model_without_ddp.state_dict())
527
+ ema_state_dict = copy.deepcopy(model_without_ddp.state_dict())
528
+ for i, (name, _value) in enumerate(model_without_ddp.named_parameters()):
529
+ if "nested_sampler" in name:
530
+ continue
531
+ if name in self.ema_model.state_dict():
532
+ ema_state_dict[name] = self.ema_model.state_dict()[name]
533
+ if self.accelerator.is_main_process:
534
+ print("Switch to ema")
535
+ msg = model_without_ddp.load_state_dict(ema_state_dict)
536
+ if self.accelerator.is_main_process:
537
+ print(msg)
538
+ else:
539
+ print("EMA model not found, using original model")
540
+ use_ema = False
541
+
542
+ if not self.test_only:
543
+ with tqdm(
544
+ self.valid_dl,
545
+ dynamic_ncols=True,
546
+ disable=not self.accelerator.is_main_process,
547
+ ) as valid_dl:
548
+ for batch_i, batch in enumerate(valid_dl):
549
+ if isinstance(batch, tuple) or isinstance(batch, list):
550
+ img, targets = batch[0], batch[1]
551
+ else:
552
+ img = batch
553
+
554
+ with self.accelerator.autocast():
555
+ rec = self.model(img, targets, sample=True, inference_with_n_slots=self.test_num_slots, cfg=1.0)
556
+ imgs_and_recs = torch.stack((img.to(rec.device), rec), dim=0)
557
+ imgs_and_recs = rearrange(imgs_and_recs, "r b ... -> (b r) ...")
558
+ imgs_and_recs = imgs_and_recs.detach().cpu().float()
559
+
560
+ grid = make_grid(
561
+ imgs_and_recs, nrow=6, normalize=True, value_range=(0, 1)
562
+ )
563
+ if self.accelerator.is_main_process:
564
+ save_image(
565
+ grid,
566
+ os.path.join(
567
+ self.image_saved_dir, f"step_{self.steps}_slots{self.test_num_slots}_{batch_i}.jpg"
568
+ ),
569
+ )
570
+
571
+ if self.cfg != 1.0:
572
+ with self.accelerator.autocast():
573
+ rec = self.model(img, targets, sample=True, inference_with_n_slots=self.test_num_slots, cfg=self.cfg)
574
+
575
+ imgs_and_recs = torch.stack((img.to(rec.device), rec), dim=0)
576
+ imgs_and_recs = rearrange(imgs_and_recs, "r b ... -> (b r) ...")
577
+ imgs_and_recs = imgs_and_recs.detach().cpu().float()
578
+
579
+ grid = make_grid(
580
+ imgs_and_recs, nrow=6, normalize=True, value_range=(0, 1)
581
+ )
582
+ if self.accelerator.is_main_process:
583
+ save_image(
584
+ grid,
585
+ os.path.join(
586
+ self.image_saved_dir, f"step_{self.steps}_cfg_{self.cfg}_slots{self.test_num_slots}_{batch_i}.jpg"
587
+ ),
588
+ )
589
+ if (self.eval_fid and self.test_dl is not None) and (self.test_only or (self.steps % self.fid_every == 0)):
590
+ # Create output directories
591
+ if self.test_dataset_size > 10000:
592
+ real_dir = "./dataset/imagenet/val256"
593
+ else:
594
+ real_dir = "./dataset/coco/val2017_256"
595
+ rec_dir = os.path.join(self.image_saved_dir, f"rec_step{self.steps}_slots{self.test_num_slots}")
596
+ os.makedirs(rec_dir, exist_ok=True)
597
+
598
+ if self.cfg != 1.0:
599
+ rec_cfg_dir = os.path.join(self.image_saved_dir, f"rec_step{self.steps}_cfg_{self.cfg}_slots{self.test_num_slots}")
600
+ os.makedirs(rec_cfg_dir, exist_ok=True)
601
+
602
+ def process_batch(cfg_value, save_dir, header):
603
+ logger = MetricLogger(delimiter=" ")
604
+ print_freq = 5
605
+ psnr_values = []
606
+ ssim_values = []
607
+ total_processed = 0
608
+
609
+ for batch_i, batch in enumerate(logger.log_every(self.test_dl, print_freq, header)):
610
+ imgs, targets = (batch[0], batch[1]) if isinstance(batch, (tuple, list)) else (batch, None)
611
+
612
+ # Skip processing if we've already processed all real samples
613
+ if total_processed >= self.test_dataset_size:
614
+ break
615
+
616
+ imgs = imgs.to(self.device, non_blocking=True)
617
+ if targets is not None:
618
+ targets = targets.to(self.device, non_blocking=True)
619
+
620
+ with self.accelerator.autocast():
621
+ recs = self.model(imgs, targets, sample=True, inference_with_n_slots=self.test_num_slots, cfg=cfg_value)
622
+
623
+ psnr_val = psnr(recs, imgs, data_range=1.0)
624
+ ssim_val = ssim(recs, imgs, data_range=1.0)
625
+
626
+ recs = concat_all_gather(recs).detach()
627
+ psnr_val = concat_all_gather(psnr_val.view(1))
628
+ ssim_val = concat_all_gather(ssim_val.view(1))
629
+
630
+ # Remove padding after gathering from all GPUs
631
+ samples_in_batch = min(
632
+ recs.size(0), # Always use the gathered size
633
+ self.test_dataset_size - total_processed
634
+ )
635
+ recs = recs[:samples_in_batch]
636
+ psnr_val = psnr_val[:samples_in_batch]
637
+ ssim_val = ssim_val[:samples_in_batch]
638
+ psnr_values.append(psnr_val)
639
+ ssim_values.append(ssim_val)
640
+
641
+ if self.accelerator.is_main_process:
642
+ rec_paths = [os.path.join(save_dir, f"step_{self.steps}_slots{self.test_num_slots}_{batch_i}_{j}_rec_cfg_{cfg_value}_slots{self.test_num_slots}.png")
643
+ for j in range(recs.size(0))]
644
+ save_img_batch(recs.cpu(), rec_paths)
645
+
646
+ total_processed += samples_in_batch
647
+
648
+ self.accelerator.wait_for_everyone()
649
+
650
+ return torch.cat(psnr_values).mean(), torch.cat(ssim_values).mean()
651
+
652
+ # Helper function to calculate and log metrics
653
+ def calculate_and_log_metrics(real_dir, rec_dir, cfg_value, psnr_val, ssim_val):
654
+ if self.accelerator.is_main_process:
655
+ metrics_dict = get_fid_stats(real_dir, rec_dir, self.fid_stats)
656
+ fid = metrics_dict["frechet_inception_distance"]
657
+ inception_score = metrics_dict["inception_score_mean"]
658
+
659
+ metric_prefix = "fid_ema" if use_ema else "fid"
660
+ isc_prefix = "isc_ema" if use_ema else "isc"
661
+ self.accelerator.log({
662
+ metric_prefix: fid,
663
+ isc_prefix: inception_score,
664
+ f"psnr_{'ema' if use_ema else 'test'}": psnr_val,
665
+ f"ssim_{'ema' if use_ema else 'test'}": ssim_val,
666
+ "cfg": cfg_value
667
+ }, step=self.steps)
668
+
669
+ print(f"{'EMA ' if use_ema else ''}{f'CFG: {cfg_value}'} "
670
+ f"FID: {fid:.2f}, ISC: {inception_score:.2f}, "
671
+ f"PSNR: {psnr_val:.2f}, SSIM: {ssim_val:.4f}")
672
+
673
+ # Process without CFG
674
+ if self.cfg == 1.0 or not self.test_only:
675
+ psnr_val, ssim_val = process_batch(1.0, rec_dir, 'Testing: w/o CFG')
676
+ calculate_and_log_metrics(real_dir, rec_dir, 1.0, psnr_val, ssim_val)
677
+
678
+ # Process with CFG if needed
679
+ if self.cfg != 1.0:
680
+ psnr_val, ssim_val = process_batch(self.cfg, rec_cfg_dir, 'Testing: w/ CFG')
681
+ calculate_and_log_metrics(real_dir, rec_cfg_dir, self.cfg, psnr_val, ssim_val)
682
+
683
+ # Cleanup
684
+ if self.accelerator.is_main_process:
685
+ shutil.rmtree(rec_dir)
686
+ if self.cfg != 1.0:
687
+ shutil.rmtree(rec_cfg_dir)
688
+
689
+ # back to no ema
690
+ if use_ema:
691
+ if self.accelerator.is_main_process:
692
+ print("Switch back from ema")
693
+ model_without_ddp.load_state_dict(model_state_dict)
694
+
695
+ self.model.train()
paintmind/engine/util.py ADDED
@@ -0,0 +1,572 @@
1
+ import numpy as np
2
+ import os.path as osp
3
+ import torch_fidelity
4
+ from PIL import Image
5
+ from tqdm import tqdm
6
+ import pickle as pkl
7
+ import os, hashlib, pdb
8
+ from pathlib import Path
9
+ from torch import Tensor
10
+ import torch, torchvision
11
+ from einops import rearrange
12
+ from omegaconf import OmegaConf
13
+ import torch.distributed as dist
14
+ from typing import List, Optional
15
+ from torchvision import transforms
16
+ from io import BytesIO as Bytes2Data
17
+ from smart_open import open
18
+ from .misc import is_main_process, get_rank
19
+ import importlib, datetime, requests, time, shutil
20
+ from collections import defaultdict, deque, OrderedDict
21
+ from dotwiz import DotWiz
22
+
23
+ URL_MAP = {
24
+ "vgg_lpips": "https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1"
25
+ }
26
+
27
+ CKPT_MAP = {
28
+ "vgg_lpips": "vgg.pth"
29
+ }
30
+
31
+ MD5_MAP = {
32
+ "vgg_lpips": "d507d7349b931f0638a25a48a722f98a"
33
+ }
34
+
35
+ def disabled_train(self, mode=True):
36
+ """Overwrite model.train with this function to make sure train/eval mode
37
+ does not change anymore."""
38
+ return self
39
+
40
+ def customized_collate_fn(batch):
41
+
42
+ collate_fn = {}
43
+ if len(batch) < 2:
44
+ for key, value in batch[0].items():
45
+ collate_fn[key] = [value]
46
+ else:
47
+
48
+ for i, dd in enumerate(batch):
49
+ if i < 1:
50
+ for key, value in dd.items():
51
+ collate_fn[key] = [value]
52
+ else:
53
+ for key, value in dd.items():
54
+ collate_fn[key].append(value)
55
+
56
+ return collate_fn
57
+
58
+
59
+ def trivial_batch_collator(batch):
60
+ """
61
+ A batch collator that does nothing.
62
+ """
63
+ return batch
64
+
65
+ class NestedTensor(object):
66
+ def __init__(self, tensors, mask: Optional[Tensor]):
67
+ self.tensors = tensors
68
+ self.mask = mask
69
+
70
+ def to(self, device):
71
+ # type: (Device) -> NestedTensor # noqa
72
+ cast_tensor = self.tensors.to(device)
73
+ mask = self.mask
74
+ if mask is not None:
75
+ assert mask is not None
76
+ cast_mask = mask.to(device)
77
+ else:
78
+ cast_mask = None
79
+ return NestedTensor(cast_tensor, cast_mask)
80
+
81
+ def decompose(self):
82
+ return self.tensors, self.mask
83
+
84
+ def __repr__(self):
85
+ return str(self.tensors)
86
+
87
+
88
+ def _max_by_axis(the_list):
89
+ # type: (List[List[int]]) -> List[int]
90
+ maxes = the_list[0]
91
+ for sublist in the_list[1:]:
92
+ for index, item in enumerate(sublist):
93
+ maxes[index] = max(maxes[index], item)
94
+ return maxes
95
+
96
+
97
+ def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
98
+ # TODO make this more general
99
+ if tensor_list[0].ndim == 3:
100
+ if torchvision._is_tracing():
101
+ # nested_tensor_from_tensor_list() does not export well to ONNX
102
+ # call _onnx_nested_tensor_from_tensor_list() instead
103
+ return _onnx_nested_tensor_from_tensor_list(tensor_list)
104
+
105
+ # TODO make it support different-sized images
106
+ max_size = _max_by_axis([list(img.shape) for img in tensor_list])
107
+ # min_size = tuple(min(s) for s in zip(*[img.shape for img in tensor_list]))
108
+ batch_shape = [len(tensor_list)] + max_size
109
+ b, c, h, w = batch_shape
110
+ dtype = tensor_list[0].dtype
111
+ device = tensor_list[0].device
112
+ tensor = torch.zeros(batch_shape, dtype=dtype, device=device)
113
+ mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
114
+ for img, pad_img, m in zip(tensor_list, tensor, mask):
115
+ pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
116
+ m[: img.shape[1], : img.shape[2]] = False
117
+ else:
118
+ raise ValueError("not supported")
119
+ return NestedTensor(tensor, mask)
120
+
121
+
122
+ # _onnx_nested_tensor_from_tensor_list() is an implementation of
123
+ # nested_tensor_from_tensor_list() that is supported by ONNX tracing.
124
+ @torch.jit.unused
125
+ def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
126
+ max_size = []
127
+ for i in range(tensor_list[0].dim()):
128
+ max_size_i = torch.max(
129
+ torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)
130
+ ).to(torch.int64)
131
+ max_size.append(max_size_i)
132
+ max_size = tuple(max_size)
133
+
134
+ # work around for
135
+ # pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
136
+ # m[: img.shape[1], :img.shape[2]] = False
137
+ # which is not yet supported in onnx
138
+ padded_imgs = []
139
+ padded_masks = []
140
+ for img in tensor_list:
141
+ padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
142
+ padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
143
+ padded_imgs.append(padded_img)
144
+
145
+ m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
146
+ padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
147
+ padded_masks.append(padded_mask.to(torch.bool))
148
+
149
+ tensor = torch.stack(padded_imgs)
150
+ mask = torch.stack(padded_masks)
151
+
152
+ return NestedTensor(tensor, mask=mask)
153
+
154
+
155
+ def is_dist_avail_and_initialized():
156
+ if not dist.is_available():
157
+ return False
158
+ if not dist.is_initialized():
159
+ return False
160
+ return True
161
+
162
+
163
+ class SmoothedValue(object):
164
+ """Track a series of values and provide access to smoothed values over a
165
+ window or the global series average.
166
+ """
167
+
168
+ def __init__(self, window_size=20, fmt=None):
169
+ if fmt is None:
170
+ fmt = "{median:.4f} ({global_avg:.4f})"
171
+ self.deque = deque(maxlen=window_size)
172
+ self.total = 0.0
173
+ self.count = 0
174
+ self.fmt = fmt
175
+
176
+ def update(self, value, n=1):
177
+ self.deque.append(value)
178
+ self.count += n
179
+ self.total += value * n
180
+
181
+ def synchronize_between_processes(self):
182
+ """
183
+ Warning: does not synchronize the deque!
184
+ """
185
+ if not is_dist_avail_and_initialized():
186
+ return
187
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
188
+ dist.barrier()
189
+ dist.all_reduce(t)
190
+ t = t.tolist()
191
+ self.count = int(t[0])
192
+ self.total = t[1]
193
+
194
+ @property
195
+ def median(self):
196
+ d = torch.tensor(list(self.deque))
197
+ return d.median().item()
198
+
199
+ @property
200
+ def avg(self):
201
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
202
+ return d.mean().item()
203
+
204
+ @property
205
+ def global_avg(self):
206
+ return self.total / self.count
207
+
208
+ @property
209
+ def max(self):
210
+ return max(self.deque)
211
+
212
+ @property
213
+ def value(self):
214
+ return self.deque[-1]
215
+
216
+ def __str__(self):
217
+ return self.fmt.format(
218
+ median=self.median,
219
+ avg=self.avg,
220
+ global_avg=self.global_avg,
221
+ max=self.max,
222
+ value=self.value)
223
+
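A small illustration of the windowed vs. global statistics that `SmoothedValue` tracks (arbitrary numbers; the import assumes this file is available as `paintmind.engine.util`).

```python
from paintmind.engine.util import SmoothedValue   # the class defined above

meter = SmoothedValue(window_size=3, fmt="{median:.2f} ({global_avg:.2f})")
for v in [1.0, 2.0, 3.0, 10.0]:
    meter.update(v)
print(meter)   # median over the last 3 values, global average over all 4
assert meter.global_avg == 4.0 and meter.median == 3.0
```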
224
+
225
+ class MetricLogger(object):
226
+ def __init__(self, delimiter="\t"):
227
+ self.meters = defaultdict(SmoothedValue)
228
+ self.delimiter = delimiter
229
+
230
+ def update(self, **kwargs):
231
+ for k, v in kwargs.items():
232
+ if v is None:
233
+ continue
234
+ if isinstance(v, torch.Tensor):
235
+ v = v.item()
236
+ assert isinstance(v, (float, int))
237
+ self.meters[k].update(v)
238
+
239
+ def __getattr__(self, attr):
240
+ if attr in self.meters:
241
+ return self.meters[attr]
242
+ if attr in self.__dict__:
243
+ return self.__dict__[attr]
244
+ raise AttributeError("'{}' object has no attribute '{}'".format(
245
+ type(self).__name__, attr))
246
+
247
+ def __str__(self):
248
+ loss_str = []
249
+ for name, meter in self.meters.items():
250
+ loss_str.append(
251
+ "{}: {}".format(name, str(meter))
252
+ )
253
+ return self.delimiter.join(loss_str)
254
+
255
+ def synchronize_between_processes(self):
256
+ for meter in self.meters.values():
257
+ meter.synchronize_between_processes()
258
+
259
+ def add_meter(self, name, meter):
260
+ self.meters[name] = meter
261
+
262
+ def log_every(self, iterable, print_freq, header=None):
263
+ i = 0
264
+ if not header:
265
+ header = ''
266
+ start_time = time.time()
267
+ end = time.time()
268
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
269
+ data_time = SmoothedValue(fmt='{avg:.4f}')
270
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
271
+ log_msg = [
272
+ header,
273
+ '[{0' + space_fmt + '}/{1}]',
274
+ 'eta: {eta}',
275
+ '{meters}',
276
+ 'time: {time}',
277
+ 'data: {data}'
278
+ ]
279
+ if torch.cuda.is_available():
280
+ log_msg.append('max mem: {memory:.0f}')
281
+ log_msg = self.delimiter.join(log_msg)
282
+ MB = 1024.0 * 1024.0
283
+ for obj in iterable:
284
+ data_time.update(time.time() - end)
285
+ yield obj
286
+ iter_time.update(time.time() - end)
287
+ if i % print_freq == 0 or i == len(iterable) - 1:
288
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
289
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
290
+ if torch.cuda.is_available():
291
+ print(log_msg.format(
292
+ i, len(iterable), eta=eta_string,
293
+ meters=str(self),
294
+ time=str(iter_time), data=str(data_time),
295
+ memory=torch.cuda.max_memory_allocated() / MB))
296
+ else:
297
+ print(log_msg.format(
298
+ i, len(iterable), eta=eta_string,
299
+ meters=str(self),
300
+ time=str(iter_time), data=str(data_time)))
301
+ i += 1
302
+ end = time.time()
303
+ total_time = time.time() - start_time
304
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
305
+ print('{} Total time: {} ({:.4f} s / it)'.format(
306
+ header, total_time_str, total_time / len(iterable)))
307
+
308
+
309
+ def all_reduce_mean(x):
310
+ world_size = dist.get_world_size()
311
+ if world_size > 1:
312
+ x_reduce = torch.tensor(x).cuda()
313
+ dist.all_reduce(x_reduce)
314
+ x_reduce /= world_size
315
+ return x_reduce.item()
316
+ else:
317
+ return x
318
+
319
+
320
+ class NativeScaler:
321
+ state_dict_key = "amp_scaler"
322
+
323
+ def __init__(self):
324
+ self._scaler = torch.cuda.amp.GradScaler()
325
+
326
+ def __call__(self, loss, optimizer, clip_grad=3., parameters=None, create_graph=False, update_grad=True):
327
+ self._scaler.scale(loss).backward(create_graph=create_graph)
328
+ if update_grad:
329
+ if clip_grad is not None:
330
+ assert parameters is not None
331
+ self._scaler.unscale_(optimizer) # unscale the gradients of optimizer's assigned params in-place
332
+ norm = torch.nn.utils.clip_grad_norm_(parameters, clip_grad)
333
+ else:
334
+ self._scaler.unscale_(optimizer)
335
+ norm = get_grad_norm_(parameters)
336
+ self._scaler.step(optimizer)
337
+ self._scaler.update()
338
+ else:
339
+ norm = None
340
+ return norm
341
+
342
+ def state_dict(self):
343
+ return self._scaler.state_dict()
344
+
345
+ def load_state_dict(self, state_dict):
346
+ self._scaler.load_state_dict(state_dict)
347
+
348
+
349
+ def get_grad_norm_(parameters, norm_type: float = 2.0) -> torch.Tensor:
350
+ if isinstance(parameters, torch.Tensor):
351
+ parameters = [parameters]
352
+ parameters = [p for p in parameters if p.grad is not None and p.requires_grad]
353
+ norm_type = float(norm_type)
354
+ if len(parameters) == 0:
355
+ return torch.tensor(0.)
356
+ device = parameters[0].grad.device
357
+ total_norm = torch.norm(torch.stack([torch.norm(p.grad.detach(), norm_type).to(device) for p in parameters]),
358
+ norm_type)
359
+ return total_norm
360
+
361
+
362
+ def get_obj_from_str(string, reload=False):
363
+ module, cls = string.rsplit(".", 1)
364
+ if reload:
365
+ module_imp = importlib.import_module(module)
366
+ importlib.reload(module_imp)
367
+ return getattr(importlib.import_module(module, package=None), cls)
368
+
369
+
370
+ def instantiate_from_config(config):
371
+ if not "target" in config:
372
+ raise KeyError("Expected key `target` to instantiate.")
373
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
374
+
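Typical usage of the config-driven instantiation above. The target class here is only an illustration; in this repo the `target`/`params` pairs come from the YAML configs, and the import assumes this file is available as `paintmind.engine.util`.

```python
from paintmind.engine.util import instantiate_from_config   # defined above

config = {
    "target": "torch.nn.Linear",                 # any importable dotted path works
    "params": {"in_features": 8, "out_features": 4},
}
layer = instantiate_from_config(config)          # equivalent to torch.nn.Linear(8, 4)
```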
375
+
376
+ def save_on_master(*args, **kwargs):
377
+ if dist.get_rank() == 0:
378
+ torch.save(*args, **kwargs)
379
+
380
+
381
+ def save_model(args, epoch, model, model_without_ddp, optimizer_g, optimizer_d, loss_scaler):
382
+ output_dir = Path(args.output_dir)
383
+ epoch_name = str(epoch)
384
+ if loss_scaler is not None:
385
+ checkpoint_paths = [output_dir / ('checkpoint-%s.pth' % epoch_name)]
386
+ for checkpoint_path in checkpoint_paths:
387
+ to_save = {
388
+ 'model': model_without_ddp.state_dict(),
389
+ 'optimizer_g': optimizer_g.state_dict(),
390
+ 'optimizer_d': optimizer_d.state_dict(),
391
+ 'epoch': epoch,
392
+ 'scaler': loss_scaler.state_dict(),
393
+ 'args': args,
394
+ }
395
+
396
+ save_on_master(to_save, checkpoint_path)
397
+ else:
398
+ client_state = {'epoch': epoch}
399
+ model.save_checkpoint(save_dir=args.output_dir, tag="checkpoint-%s" % epoch_name, client_state=client_state)
400
+
401
+
402
+ def load_model(args, model_without_ddp, optimizer_g, optimizer_d, loss_scaler):
403
+ if args.resume:
404
+ if args.resume.startswith('https'):
405
+ checkpoint = torch.hub.load_state_dict_from_url(
406
+ args.resume, map_location='cpu', check_hash=True)
407
+ else:
408
+ checkpoint = torch.load(args.resume, map_location='cpu')
409
+ model_without_ddp.load_state_dict(checkpoint['model'])
410
+ print("Resume checkpoint %s" % args.resume)
411
+ if 'optimizer_g' in checkpoint and 'epoch' in checkpoint and not (hasattr(args, 'eval') and args.eval):
412
+ optimizer_g.load_state_dict(checkpoint['optimizer_g'])
413
+ optimizer_d.load_state_dict(checkpoint['optimizer_d'])
414
+ args.start_epoch = checkpoint['epoch'] + 1
415
+ if 'scaler' in checkpoint:
416
+ loss_scaler.load_state_dict(checkpoint['scaler'])
417
+ print("With optim & sched!")
418
+
419
+
420
+ def download(url, local_path, chunk_size=1024):
421
+ os.makedirs(os.path.split(local_path)[0], exist_ok=True)
422
+ with requests.get(url, stream=True) as r:
423
+ total_size = int(r.headers.get("content-length", 0))
424
+ with tqdm(total=total_size, unit="B", unit_scale=True) as pbar:
425
+ with open(local_path, "wb") as f:
426
+ for data in r.iter_content(chunk_size=chunk_size):
427
+ if data:
428
+ f.write(data)
429
+ pbar.update(chunk_size)
430
+
431
+
432
+ def md5_hash(path):
433
+ with open(path, "rb") as f:
434
+ content = f.read()
435
+ return hashlib.md5(content).hexdigest()
436
+
437
+
438
+ def get_ckpt_path(name, root, check=False):
439
+ assert name in URL_MAP
440
+ path = os.path.join(root, CKPT_MAP[name])
441
+ if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
442
+ print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
443
+ download(URL_MAP[name], path)
444
+ md5 = md5_hash(path)
445
+ assert md5 == MD5_MAP[name], md5
446
+ return path
447
+
448
+
449
+ class KeyNotFoundError(Exception):
450
+ def __init__(self, cause, keys=None, visited=None):
451
+ self.cause = cause
452
+ self.keys = keys
453
+ self.visited = visited
454
+ messages = list()
455
+ if keys is not None:
456
+ messages.append("Key not found: {}".format(keys))
457
+ if visited is not None:
458
+ messages.append("Visited: {}".format(visited))
459
+ messages.append("Cause:\n{}".format(cause))
460
+ message = "\n".join(messages)
461
+ super().__init__(message)
462
+
463
+
464
+ def retrieve(
465
+ list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
466
+ ):
467
+ """Given a nested list or dict return the desired value at key expanding
468
+ callable nodes if necessary and :attr:`expand` is ``True``. The expansion
469
+ is done in-place.
470
+
471
+ Parameters
472
+ ----------
473
+ list_or_dict : list or dict
474
+ Possibly nested list or dictionary.
475
+ key : str
476
+ key/to/value, path like string describing all keys necessary to
477
+ consider to get to the desired value. List indices can also be
478
+ passed here.
479
+ splitval : str
480
+ String that defines the delimiter between keys of the
481
+ different depth levels in `key`.
482
+ default : obj
483
+ Value returned if :attr:`key` is not found.
484
+ expand : bool
485
+ Whether to expand callable nodes on the path or not.
486
+
487
+ Returns
488
+ -------
489
+ The desired value or if :attr:`default` is not ``None`` and the
490
+ :attr:`key` is not found returns ``default``.
491
+
492
+ Raises
493
+ ------
494
+ Exception if ``key`` not in ``list_or_dict`` and :attr:`default` is
495
+ ``None``.
496
+ """
497
+
498
+ keys = key.split(splitval)
499
+
500
+ success = True
501
+ try:
502
+ visited = []
503
+ parent = None
504
+ last_key = None
505
+ for key in keys:
506
+ if callable(list_or_dict):
507
+ if not expand:
508
+ raise KeyNotFoundError(
509
+ ValueError(
510
+ "Trying to get past callable node with expand=False."
511
+ ),
512
+ keys=keys,
513
+ visited=visited,
514
+ )
515
+ list_or_dict = list_or_dict()
516
+ parent[last_key] = list_or_dict
517
+
518
+ last_key = key
519
+ parent = list_or_dict
520
+
521
+ try:
522
+ if isinstance(list_or_dict, dict):
523
+ list_or_dict = list_or_dict[key]
524
+ else:
525
+ list_or_dict = list_or_dict[int(key)]
526
+ except (KeyError, IndexError, ValueError) as e:
527
+ raise KeyNotFoundError(e, keys=keys, visited=visited)
528
+
529
+ visited += [key]
530
+ # final expansion of retrieved value
531
+ if expand and callable(list_or_dict):
532
+ list_or_dict = list_or_dict()
533
+ parent[last_key] = list_or_dict
534
+ except KeyNotFoundError as e:
535
+ if default is None:
536
+ raise e
537
+ else:
538
+ list_or_dict = default
539
+ success = False
540
+
541
+ if not pass_success:
542
+ return list_or_dict
543
+ else:
544
+ return list_or_dict, success
545
+
546
+
547
+ if __name__ == "__main__":
548
+ config = {"keya": "a",
549
+ "keyb": "b",
550
+ "keyc":
551
+ {"cc1": 1,
552
+ "cc2": 2,
553
+ }
554
+ }
555
+ from omegaconf import OmegaConf
556
+
557
+ config = OmegaConf.create(config)
558
+ print(config)
559
+ retrieve(config, "keya")
560
+
561
+ def instantiate_from_config(config):
562
+
563
+ if not "target" in config:
564
+ raise KeyError("Expected key `target` to instantiate.")
565
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
566
+
567
+ def get_obj_from_str(string, reload=False):
568
+ module, cls = string.rsplit(".", 1)
569
+ if reload:
570
+ module_imp = importlib.import_module(module)
571
+ importlib.reload(module_imp)
572
+ return getattr(importlib.import_module(module, package=None), cls)
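A minimal usage sketch of the two helpers above (the target below is illustrative and not a class shipped in this repo):

# hypothetical example: build an object from a {"target", "params"} config
cfg = {"target": "torch.nn.Linear", "params": {"in_features": 16, "out_features": 4}}
layer = instantiate_from_config(cfg)   # equivalent to torch.nn.Linear(in_features=16, out_features=4)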
paintmind/stage1/__init__.py ADDED
File without changes
paintmind/stage1/diffuse_slot.py ADDED
@@ -0,0 +1,808 @@
1
+ import torch
2
+ import math
3
+ import numpy as np
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
9
+
10
+ from paintmind.stage1.diffusion import create_diffusion
11
+ from paintmind.stage1.diffusion_transfomers import DiT
12
+ from paintmind.stage1.quantize import DiagonalGaussianDistribution
13
+ from paintmind.stage1.transport import create_transport, Sampler
14
+
15
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
16
+ from transformers import SiglipVisionModel, CLIPVisionModel
17
+
18
+ CLIP_DEFAULT_MEAN = (0.48145466, 0.4578275, 0.40821073)
19
+ CLIP_DEFAULT_STD = (0.26862954, 0.26130258, 0.27577711)
20
+ SIGLIP_DEFAULT_MEAN = (0.5, 0.5, 0.5)
21
+ SIGLIP_DEFAULT_STD = (0.5, 0.5, 0.5)
22
+
23
+ def build_mlp(hidden_size, projector_dim, z_dim):
24
+ return nn.Sequential(
25
+ nn.Linear(hidden_size, projector_dim),
26
+ nn.SiLU(),
27
+ nn.Linear(projector_dim, projector_dim),
28
+ nn.SiLU(),
29
+ nn.Linear(projector_dim, z_dim),
30
+ )
31
+
32
+ class DiT_with_autoenc_cond(DiT):
33
+ def __init__(
34
+ self,
35
+ *args,
36
+ num_autoenc=32,
37
+ autoenc_dim=4,
38
+ cond_method="adaln",
39
+ mask_type="simple",
40
+ class_cond=False,
41
+ use_repa=False,
42
+ z_dim=768,
43
+ encoder_depth=8,
44
+ projector_dim=2048,
45
+ **kwargs,
46
+ ):
47
+ super().__init__(*args, **kwargs)
48
+ self.autoenc_dim = autoenc_dim
49
+ self.class_cond = class_cond
50
+ self.mask_type = mask_type
51
+ self.hidden_size = kwargs["hidden_size"]
52
+ self.cond_drop_prob = self.y_embedder.dropout_prob # 0.1 without cond guidance
53
+ self.null_cond = nn.Parameter(torch.zeros(1, num_autoenc, autoenc_dim))
54
+ torch.nn.init.normal_(self.null_cond, std=.02)
55
+ # NOTE: "adaln" injects the condition via adaptive layer normalization; "token" feeds the condition as extra tokens to the attention layers
56
+ assert cond_method in [
57
+ "adaln",
58
+ "token",
59
+ "token+adaln",
60
+ ], f"Invalid cond_method: {cond_method}"
61
+ self.cond_method = cond_method
62
+ if "token" in cond_method:
63
+ self.autoenc_cond_embedder = nn.Linear(autoenc_dim, self.hidden_size)
64
+ else:
65
+ self.autoenc_cond_embedder = nn.Linear(
66
+ num_autoenc * autoenc_dim, self.hidden_size
67
+ )
68
+
69
+ if cond_method == "token+adaln":
70
+ self.autoenc_proj_ln = nn.Linear(self.hidden_size, self.hidden_size)
71
+
72
+ if not class_cond:
73
+ self.y_embedder = nn.Identity()
74
+
75
+ self.use_repa = use_repa
76
+ self._repa_hook = None
77
+ self.encoder_depth = encoder_depth
78
+ if use_repa:
79
+ self.projector = build_mlp(self.hidden_size, projector_dim, z_dim)
80
+
81
+ def embed_cond(self, autoenc_cond, drop_mask=None):
82
+ # autoenc_cond: (N, K, D)
83
+ # drop_ids: (N)
84
+ # self.null_cond: (1, K, D)
85
+ # NOTE: this dropout will replace some condition from the autoencoder to null condition
86
+ # this is to enable classifier-free guidance.
87
+ batch_size = autoenc_cond.shape[0]
88
+ if drop_mask is None:
89
+ # randomly drop all conditions, for classifier-free guidance
90
+ if self.training:
91
+ drop_ids = (
92
+ torch.rand(batch_size, 1, 1, device=autoenc_cond.device)
93
+ < self.cond_drop_prob
94
+ )
95
+ autoenc_cond_drop = torch.where(drop_ids, self.null_cond, autoenc_cond)
96
+ else:
97
+ autoenc_cond_drop = autoenc_cond
98
+ else:
99
+ # randomly drop some conditions according to the drop_mask (N, K)
100
+ # True means keep
101
+ autoenc_cond_drop = torch.where(drop_mask[:, :, None], autoenc_cond, self.null_cond)
102
+ if "token" in self.cond_method:
103
+ return self.autoenc_cond_embedder(autoenc_cond_drop)
104
+ return self.autoenc_cond_embedder(autoenc_cond_drop.reshape(batch_size, -1))
105
+
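A rough sketch of what the training-time branch above does (names and the 0.1 probability are illustrative; the actual probability comes from y_embedder.dropout_prob):

# replace the whole slot condition of roughly 10% of the samples with the learned null condition
drop_ids = torch.rand(batch_size, 1, 1, device=slots.device) < 0.1   # (N, 1, 1)
slots_dropped = torch.where(drop_ids, null_cond, slots)              # broadcasts over (N, K, D)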
106
+ # def forward(self, x, t, y, autoenc_cond):
107
+ def forward(self, x, t, autoenc_cond, drop_mask=None, y=None):
108
+ """
109
+ Forward pass of DiT.
110
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
111
+ t: (N,) tensor of diffusion timesteps
112
+ autoenc_cond: (N, K, D) tensor of autoencoder conditions (slots)
+ drop_mask: optional boolean mask selecting which slot conditions to keep
+ y: optional (N,) tensor of class labels (used only when class_cond=True)
114
+ """
115
+ x = (
116
+ self.x_embedder(x) + self.pos_embed
117
+ ) # (N, T, D), where T = H * W / patch_size ** 2
118
+ N, T, D = x.shape
119
+
120
+ c = self.t_embedder(t) # (N, D)
121
+ if y is not None and self.class_cond:
122
+ y = self.y_embedder(y, self.training) # (N, D)
123
+ c = c + y # (N, D)
124
+
125
+ if self.mask_type == "replace":
126
+ autoenc = self.embed_cond(autoenc_cond, drop_mask)
127
+ else:
128
+ autoenc = self.embed_cond(autoenc_cond)
129
+
130
+ if self.cond_method == "adaln":
131
+ c = c + autoenc # add the encoder condition to adaln
132
+ elif self.cond_method == "token":
133
+ num_tokens = x.shape[1]
134
+ # append the autoencoder condition to the token sequence
135
+ x = torch.cat((x, autoenc), dim=1)
136
+ elif self.cond_method == "token+adaln":
137
+ c = c + self.autoenc_proj_ln(autoenc.mean(dim=1))
138
+ num_tokens = x.shape[1]
139
+ x = torch.cat((x, autoenc), dim=1)
140
+ else:
141
+ raise ValueError(f"Invalid cond_method: {self.cond_method}")
142
+
143
+ for i, block in enumerate(self.blocks):
144
+ if self.mask_type == "replace":
145
+ x = block(x, c) # (N, T, D)
146
+ else:
147
+ x = block(x, c, drop_mask) # (N, T, D)
148
+ if (i + 1) == self.encoder_depth and self.use_repa:
149
+ projected = self.projector(x)
150
+ self._repa_hook = projected[:, :num_tokens]
151
+
152
+ if "token" in self.cond_method:
153
+ x = x[:, :num_tokens]
154
+ x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
155
+ x = self.unpatchify(x) # (N, out_channels, H, W)
156
+ return x
157
+
158
+ def forward_with_cfg(self, x, t, autoenc_cond, drop_mask, y=None, cfg_scale=1.0):
159
+ """
160
+ Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
161
+ """
162
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
163
+ half = x[: len(x) // 2]
164
+ combined = torch.cat([half, half], dim=0)
165
+ model_out = self.forward(combined, t, autoenc_cond, drop_mask, y)
166
+ # Here classifier-free guidance is applied to all model channels (the standard approach).
+ # The original DiT code applied it to only the first three channels for exact
+ # reproducibility; that variant is kept in the commented-out line below.
169
+ eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
170
+ # eps, rest = model_out[:, :3], model_out[:, 3:]
171
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
172
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
173
+ eps = torch.cat([half_eps, half_eps], dim=0)
174
+ return torch.cat([eps, rest], dim=1)
175
+
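The guidance combination above is the standard classifier-free guidance update,

\[ \hat{\epsilon} = \epsilon_{\text{uncond}} + s\,(\epsilon_{\text{cond}} - \epsilon_{\text{uncond}}), \]

where \(s\) is cfg_scale; \(s = 1\) recovers the conditional prediction and larger scales push samples further toward the condition.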
176
+ #################################################################################
177
+ # DiT Configs #
178
+ #################################################################################
179
+
180
+
181
+ def DiT_with_autoenc_cond_XL_2(**kwargs):
182
+ return DiT_with_autoenc_cond(
183
+ depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs
184
+ )
185
+
186
+
187
+ def DiT_with_autoenc_cond_XL_4(**kwargs):
188
+ return DiT_with_autoenc_cond(
189
+ depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs
190
+ )
191
+
192
+
193
+ def DiT_with_autoenc_cond_XL_8(**kwargs):
194
+ return DiT_with_autoenc_cond(
195
+ depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs
196
+ )
197
+
198
+
199
+ def DiT_with_autoenc_cond_L_2(**kwargs):
200
+ return DiT_with_autoenc_cond(
201
+ depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs
202
+ )
203
+
204
+
205
+ def DiT_with_autoenc_cond_L_4(**kwargs):
206
+ return DiT_with_autoenc_cond(
207
+ depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs
208
+ )
209
+
210
+
211
+ def DiT_with_autoenc_cond_L_8(**kwargs):
212
+ return DiT_with_autoenc_cond(
213
+ depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs
214
+ )
215
+
216
+
217
+ def DiT_with_autoenc_cond_B_2(**kwargs):
218
+ return DiT_with_autoenc_cond(
219
+ depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs
220
+ )
221
+
222
+
223
+ def DiT_with_autoenc_cond_B_4(**kwargs):
224
+ return DiT_with_autoenc_cond(
225
+ depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs
226
+ )
227
+
228
+
229
+ def DiT_with_autoenc_cond_B_8(**kwargs):
230
+ return DiT_with_autoenc_cond(
231
+ depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs
232
+ )
233
+
234
+
235
+ def DiT_with_autoenc_cond_S_2(**kwargs):
236
+ return DiT_with_autoenc_cond(
237
+ depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs
238
+ )
239
+
240
+
241
+ def DiT_with_autoenc_cond_S_4(**kwargs):
242
+ return DiT_with_autoenc_cond(
243
+ depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs
244
+ )
245
+
246
+
247
+ def DiT_with_autoenc_cond_S_8(**kwargs):
248
+ return DiT_with_autoenc_cond(
249
+ depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs
250
+ )
251
+
252
+
253
+ DiT_with_autoenc_cond_models = {
254
+ "DiT-XL-2": DiT_with_autoenc_cond_XL_2,
255
+ "DiT-XL-4": DiT_with_autoenc_cond_XL_4,
256
+ "DiT-XL-8": DiT_with_autoenc_cond_XL_8,
257
+ "DiT-L-2": DiT_with_autoenc_cond_L_2,
258
+ "DiT-L-4": DiT_with_autoenc_cond_L_4,
259
+ "DiT-L-8": DiT_with_autoenc_cond_L_8,
260
+ "DiT-B-2": DiT_with_autoenc_cond_B_2,
261
+ "DiT-B-4": DiT_with_autoenc_cond_B_4,
262
+ "DiT-B-8": DiT_with_autoenc_cond_B_8,
263
+ "DiT-S-2": DiT_with_autoenc_cond_S_2,
264
+ "DiT-S-4": DiT_with_autoenc_cond_S_4,
265
+ "DiT-S-8": DiT_with_autoenc_cond_S_8,
266
+ }
267
+
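A hedged sketch of how this registry is used further below in DiffuseSlot (the argument values here are illustrative):

dit = DiT_with_autoenc_cond_models["DiT-B-4"](
    input_size=32,      # e.g. 256 // 8 for an SD-VAE latent grid
    in_channels=4,
    num_autoenc=16,
    autoenc_dim=256,
    cond_method="adaln",
)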
268
+ from torch.distributions import Geometric
269
+
270
+ class NestedSampler(nn.Module):
271
+ def __init__(
272
+ self,
273
+ num_slots,
274
+ rho=0.03,
275
+ nest_dist="geometric",
276
+ mask_type="simple",
277
+ null_prob=0.1,
278
+ allow_zero=False,
279
+ one_slot_before=0, # Use only one slot before this epoch
280
+ ):
281
+ super().__init__()
282
+ self.num_slots = num_slots
283
+ self.mask_type = mask_type
284
+ self.rho = rho
285
+ self.geometric = Geometric(rho)
286
+ self.nest_dist = nest_dist
287
+ self.null_prob = null_prob
288
+ self.allow_zero = allow_zero
289
+ self.one_slot_before = one_slot_before
290
+ self.register_buffer("arange", torch.arange(num_slots))
291
+
292
+ def _apply_epoch_constraint(self, samples, epoch=None):
293
+ if epoch is not None and epoch < self.one_slot_before:
294
+ return torch.ones_like(samples)
295
+ return samples
296
+
297
+ def _apply_null_prob(self, samples, num):
298
+ # First determine which samples will be 0 based on null_prob
299
+ null_mask = torch.rand(num, device=samples.device) < self.null_prob
300
+ # Replace with 0 where null_mask is True
301
+ return torch.where(null_mask, torch.zeros_like(samples), samples)
302
+
303
+ def geometric_sample(self, num):
304
+ return self.geometric.sample([num]) + int(not self.allow_zero)
305
+
306
+ def uniform_sample(self, num):
307
+ return torch.randint(int(not self.allow_zero), self.num_slots + 1, (num,))
308
+
309
+ def power2_uniform_sample(self, num):
310
+ # Get powers of 2 up to num_slots and add num_slots
311
+ choices = [2**i for i in range(int(math.log2(self.num_slots)) + 1)]
312
+ if self.num_slots not in choices:
313
+ choices.append(self.num_slots)
314
+ return torch.tensor(choices)[torch.randint(0, len(choices), (num,))]
315
+
316
+ def sample(self, num, epoch=None):
317
+ if self.nest_dist == "geometric":
318
+ samples = self.geometric_sample(num)
319
+ elif self.nest_dist == "uniform":
320
+ samples = self.uniform_sample(num)
321
+ elif self.nest_dist == "power2_uniform":
322
+ samples = self.power2_uniform_sample(num)
323
+ else:
324
+ raise ValueError(f"Invalid nest_dist: {self.nest_dist}")
325
+ samples = self._apply_epoch_constraint(samples, epoch)
326
+ return self._apply_null_prob(samples, num)
327
+
328
+ def forward(self, batch_size, num_patches, device, inference_with_n_slots=-1, coupled_value=None, epoch=None):
329
+ if self.training:
330
+ if coupled_value is None:
331
+ b = self.sample(batch_size, epoch).to(device)
332
+ else:
333
+ b = coupled_value.long().to(device)
334
+ else:
335
+ if inference_with_n_slots != -1:
336
+ b = torch.full((batch_size,), inference_with_n_slots, device=device)
337
+ else:
338
+ b = torch.full((batch_size,), self.num_slots, device=device)
339
+ b = torch.clamp(b, max=self.num_slots)
340
+
341
+ slot_mask = self.arange[None, :] < b[:, None] # (batch_size, num_slots)
342
+ if self.mask_type == "replace":
343
+ return slot_mask
344
+ else:
345
+ return self.get_cond_attn_mask(slot_mask.unsqueeze(1), num_patches, self.num_slots, device)
346
+
347
+ def get_cond_attn_mask(self, slot_mask, num_patches, num_slots, device):
348
+ num_tokens = num_patches + num_slots
349
+ batch_size = slot_mask.shape[0]
350
+ if self.mask_type == "simple":
351
+ attn_mask = torch.ones((batch_size, num_tokens, num_tokens), dtype=torch.bool, device=device)
352
+ attn_mask[:, :, num_patches:] = slot_mask.expand(-1, num_tokens, -1)
353
+ elif self.mask_type == "causal":
354
+ attn_mask = torch.zeros((batch_size, num_tokens, num_tokens), dtype=torch.bool, device=device)
355
+ # 1) patches can see each other
356
+ attn_mask[:, :num_patches, :num_patches] = True
357
+ # 2) patches cannot see the last few (dropped) slots
358
+ slot_mask = slot_mask.expand(-1, num_patches, -1)
359
+ attn_mask[:, :num_patches, num_patches:] = slot_mask
360
+ # 3) remaining slots are causal to each other
361
+ causal_mask = torch.ones((num_slots, num_slots), dtype=torch.bool, device=device).tril(diagonal=0)
362
+ attn_mask[:, num_patches:, num_patches:] = causal_mask
363
+ # 4) only the first slot can see the patches
364
+ attn_mask[:, num_patches + 1:, :num_patches] = False
365
+ else:
366
+ raise NotImplementedError(f"Invalid mask_type: {self.mask_type}")
367
+ return attn_mask.unsqueeze(1) # (batch_size, 1, num_tokens, num_tokens)
368
+
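A small worked example of the nesting mask (values illustrative): comparing the sampled per-image slot count b against arange yields a prefix keep-mask over the slot sequence.

b = torch.tensor([3, 8, 1])                 # sampled number of slots for 3 images, num_slots = 8
arange = torch.arange(8)
slot_mask = arange[None, :] < b[:, None]    # (3, 8); row 0 keeps slots 0-2, row 1 keeps all, row 2 keeps only slot 0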
369
+ class DiffuseSlot(nn.Module):
370
+ def __init__(
371
+ self,
372
+ encoder="vit_base_patch16",
373
+ drop_path_rate=0.1,
374
+ enc_img_size=256,
375
+ enc_causal=True,
376
+ enc_use_mlp=False,
377
+ enc_hidden_dim=4096,
378
+ num_slots=16,
379
+ slot_dim=256,
380
+ slot_through=True,
381
+ norm_slots=False,
382
+ use_kl_loss=False,
383
+ kl_loss_weight=1e-6,
384
+ enable_nest=False,
385
+ enable_nest_after=-1,
386
+ nest_dist="geometric",
387
+ nest_rho=0.03,
388
+ nest_null_prob=0.1,
389
+ nest_allow_zero=False,
390
+ nest_one_slot_before=0,
391
+ coupled_sampling=False,
392
+ coupled_rho=-0.8,
393
+ dit_class_cond=False,
394
+ dit_mask_type="simple",
395
+ cond_method="adaln",
396
+ dit_model="DiT-B-4",
397
+ vae="stabilityai/sd-vae-ft-ema",
398
+ vae_path="pretrained_models/kl16.ckpt",
399
+ pretrained_dit=None,
400
+ pretrained_encoder=None,
401
+ freeze_dit=False,
402
+ freeze_vit_after=-1,
403
+ num_sampling_steps="ddim25",
404
+ ckpt_path=None,
405
+ ema_path=None,
406
+ use_repa=False,
407
+ repa_encoder="dinov2_vitb14",
408
+ repa_encoder_depth=8,
409
+ repa_loss_weight=1.0,
410
+ use_sit=False,
411
+ **kwargs,
412
+ ):
413
+ super().__init__()
414
+
415
+ z_dim = 0
416
+ self.use_repa = use_repa
417
+ self.repa_encoder_name = repa_encoder
418
+ self.repa_loss_weight = repa_loss_weight
419
+ self.use_sit = use_sit
420
+ if use_repa:
421
+ if "dinov2" in repa_encoder:
422
+ if "vitb" in repa_encoder or "vit_b" in repa_encoder or "vit-b" in repa_encoder:
423
+ self.repa_encoder = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
424
+ elif "vitl" in repa_encoder or "vit_l" in repa_encoder or "vit-l" in repa_encoder:
425
+ self.repa_encoder = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitl14')
426
+ else:
427
+ raise ValueError(f"Invalid repa_encoder: {repa_encoder}")
428
+ self.repa_encoder.image_size = 224
429
+ elif "clip" in repa_encoder:
430
+ if "vitb" in repa_encoder or "vit_b" in repa_encoder or "vit-b" in repa_encoder or "vit-base" in repa_encoder:
431
+ self.repa_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch16")
432
+ elif "vitl" in repa_encoder or "vit_l" in repa_encoder or "vit-l" in repa_encoder or "vit-large" in repa_encoder:
433
+ self.repa_encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14")
434
+ else:
435
+ raise ValueError(f"Invalid repa_encoder: {repa_encoder}")
436
+ self.repa_encoder.embed_dim = self.repa_encoder.config.hidden_size
437
+ self.repa_encoder.image_size = self.repa_encoder.config.image_size
438
+ elif "siglip2" in repa_encoder:
439
+ if "vitb" in repa_encoder or "vit_b" in repa_encoder or "vit-b" in repa_encoder or "vit-base" in repa_encoder:
440
+ self.repa_encoder = SiglipVisionModel.from_pretrained("google/siglip2-base-patch16-256")
441
+ elif "vitl" in repa_encoder or "vit_l" in repa_encoder or "vit-l" in repa_encoder or "vit-large" in repa_encoder:
442
+ self.repa_encoder = SiglipVisionModel.from_pretrained("google/siglip2-large-patch16-256")
443
+ else:
444
+ raise ValueError(f"Invalid repa_encoder: {repa_encoder}")
445
+ self.repa_encoder.embed_dim = self.repa_encoder.config.hidden_size
446
+ self.repa_encoder.image_size = self.repa_encoder.config.image_size
447
+ elif "siglip" in repa_encoder:
448
+ if "vitb" in repa_encoder or "vit_b" in repa_encoder or "vit-b" in repa_encoder or "vit-base" in repa_encoder:
449
+ self.repa_encoder = SiglipVisionModel.from_pretrained("google/siglip-base-patch16-256")
450
+ elif "vitl" in repa_encoder or "vit_l" in repa_encoder or "vit-l" in repa_encoder or "vit-large" in repa_encoder:
451
+ self.repa_encoder = SiglipVisionModel.from_pretrained("google/siglip-large-patch16-256")
452
+ else:
453
+ raise ValueError(f"Invalid repa_encoder: {repa_encoder}")
454
+ self.repa_encoder.embed_dim = self.repa_encoder.config.hidden_size
455
+ self.repa_encoder.image_size = self.repa_encoder.config.image_size
456
+ else:
457
+ raise ValueError(f"Invalid repa_encoder: {repa_encoder}")
458
+ for param in self.repa_encoder.parameters():
459
+ param.requires_grad = False
460
+ self.repa_encoder.eval()
461
+ z_dim = self.repa_encoder.embed_dim
462
+
463
+ # DiT part
464
+ if not use_sit:
465
+ self.diffusion = create_diffusion(timestep_respacing="")
466
+ self.gen_diffusion = create_diffusion(timestep_respacing=num_sampling_steps)
467
+ else:
468
+ self.transport = create_transport()
469
+ self.sampler = Sampler(self.transport)
470
+ self.dit_input_size = enc_img_size // 8 if not "mar" in vae else enc_img_size // 16
471
+ self.dit_in_channels = 4 if not "mar" in vae else 16
472
+ self.dit = DiT_with_autoenc_cond_models[dit_model](
473
+ input_size=self.dit_input_size,
474
+ in_channels=self.dit_in_channels,
475
+ num_autoenc=num_slots,
476
+ autoenc_dim=slot_dim,
477
+ cond_method=cond_method,
478
+ class_cond=dit_class_cond,
479
+ mask_type=dit_mask_type,
480
+ use_repa=use_repa,
481
+ encoder_depth=repa_encoder_depth,
482
+ z_dim=z_dim,
483
+ learn_sigma=not use_sit,
484
+ )
485
+ self.dit_patch_size = self.dit.x_embedder.patch_size[0]
486
+ self.dit_num_patches = (self.dit_input_size // self.dit_patch_size) ** 2
487
+ self.pretrained_dit = pretrained_dit
488
+ if pretrained_dit is not None:
489
+ # now we load some pretrained model
490
+ dit_ckpt = torch.load(pretrained_dit, map_location="cpu")
491
+ msg = self.dit.load_state_dict(dit_ckpt, strict=False)
492
+ print("Load DiT from ckpt")
493
+ print(msg)
494
+ self.freeze_dit = freeze_dit
495
+ if freeze_dit:
496
+ assert pretrained_dit is not None, "pretrained_dit must be provided"
497
+ for param in self.dit.parameters():
498
+ param.requires_grad = False
499
+
500
+ if "mar" in vae:
501
+ from diffusers import AutoencoderKL
502
+ self.vae = AutoencoderKL.from_pretrained("xwen99/mar-vae-kl16")
503
+ self.scaling_factor = 0.2325
504
+ elif vae == "openai/consistency-decoder":
505
+ from diffusers import ConsistencyDecoderVAE
506
+ self.vae = ConsistencyDecoderVAE.from_pretrained(vae)
507
+ self.scaling_factor = 0.18215
508
+ else: # eg, "stabilityai/sd-vae-ft-ema"
509
+ from diffusers import AutoencoderKL
510
+ self.vae = AutoencoderKL.from_pretrained(vae)
511
+ self.scaling_factor = 0.18215
512
+
513
+ self.vae.eval().requires_grad_(False)
514
+
515
+ # image encoder part
516
+ import paintmind.stage1.vision_transformers as vision_transformer
517
+
518
+ self.enc_img_size = enc_img_size
519
+ self.enc_causal = enc_causal
520
+ encoder_fn = vision_transformer.__dict__[encoder]
521
+
522
+ self.encoder = encoder_fn(
523
+ img_size=[enc_img_size],
524
+ num_slots=num_slots,
525
+ slot_through=slot_through,
526
+ drop_path_rate=drop_path_rate,
527
+ )
528
+ self.num_slots = num_slots
529
+ self.norm_slots = norm_slots
530
+ self.use_kl_loss = use_kl_loss
531
+ self.kl_loss_weight = kl_loss_weight
532
+ self.num_channels = self.encoder.num_features
533
+ self.pretrained_encoder = pretrained_encoder
534
+ if pretrained_encoder is not None:
535
536
+ encoder_ckpt = torch.load(pretrained_encoder, map_location="cpu")
537
+ # drop pos_embed from ckpt
538
+ encoder_ckpt = {
539
+ k.replace("blocks.", "blocks.0."): v
540
+ for k, v in encoder_ckpt.items()
541
+ if not k.startswith("pos_embed")
542
+ }
543
+ msg = self.encoder.load_state_dict(encoder_ckpt, strict=False)
544
+ print("Load encoder from ckpt")
545
+ print(msg)
546
+
547
+ if not enc_use_mlp:
548
+ self.encoder2slot = nn.Linear(self.num_channels, slot_dim * 2 if self.use_kl_loss else slot_dim)
549
+ else:
550
+ self.encoder2slot = nn.Sequential(
551
+ nn.Linear(self.num_channels, enc_hidden_dim),
552
+ nn.GELU(),
553
+ nn.Linear(enc_hidden_dim, slot_dim * 2 if self.use_kl_loss else slot_dim),
554
+ )
555
+
556
+ self.nested_sampler = NestedSampler(
557
+ num_slots,
558
+ rho=nest_rho,
559
+ nest_dist=nest_dist,
560
+ mask_type=dit_mask_type,
561
+ null_prob=nest_null_prob,
562
+ allow_zero=nest_allow_zero,
563
+ one_slot_before=nest_one_slot_before,
564
+ )
565
+ self.nest_allow_zero = nest_allow_zero
566
+ self.nest_rho = nest_rho
567
+ self.use_coupled_sampling = coupled_sampling
568
+ self.couple_sampling_rho = coupled_rho
569
+ self.enable_nest = enable_nest
570
+ self.enable_nest_after = enable_nest_after
571
+ self.freeze_vit_after = freeze_vit_after
572
+ self.current_epoch = 0
573
+
574
+ def coupled_sampling(self, timestamps):
575
+ """
576
+ Convert timestamps to coupled num_slots values where higher timestamps
577
+ tend to produce lower num_slots values.
578
+
579
+ Args:
580
+ timestamps: Tensor of shape (batch_size,) with values in [0, 1000)
581
+
582
+ Returns:
583
+ Tensor of shape (batch_size,) with values in [1, num_slots + 1)
584
+ """
585
+ # Normalize timestamps to [0, 1]
586
+ t_normalized = 1 - (timestamps.float() / timestamps.max())
587
+ # Scale to [1, num_slots + 1) and round to integers
588
+ adder = int(not self.nest_allow_zero)
589
+ scaled = adder + t_normalized * (self.num_slots + 1 - adder)
590
+ num_slots2use = scaled.long().clamp(adder, self.num_slots)
591
+ return num_slots2use
592
+
593
+ @torch.no_grad()
594
+ def vae_encode(self, x):
595
+ x = x * 2 - 1
596
+ x = self.vae.encode(x)
597
+ if hasattr(x, 'latent_dist'):
598
+ x = x.latent_dist
599
+ return x.sample().mul_(self.scaling_factor)
600
+
601
+ @torch.no_grad()
602
+ def vae_decode(self, z):
603
+ z = self.vae.decode(z / self.scaling_factor)
604
+ if hasattr(z, 'sample'):
605
+ z = z.sample
606
+ return (z + 1) / 2
607
+
608
+ @torch.no_grad()
609
+ def repa_encode(self, x):
610
+ if "dinov2" in self.repa_encoder_name:
611
+ mean = torch.Tensor(IMAGENET_DEFAULT_MEAN).to(x.device).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
612
+ std = torch.Tensor(IMAGENET_DEFAULT_STD).to(x.device).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
613
+ elif "clip" in self.repa_encoder_name:
614
+ mean = torch.Tensor(CLIP_DEFAULT_MEAN).to(x.device).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
615
+ std = torch.Tensor(CLIP_DEFAULT_STD).to(x.device).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
616
+ elif "siglip" in self.repa_encoder_name:
617
+ mean = torch.Tensor(SIGLIP_DEFAULT_MEAN).to(x.device).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
618
+ std = torch.Tensor(SIGLIP_DEFAULT_STD).to(x.device).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
619
+ else:
620
+ raise ValueError(f"Invalid repa_encoder: {self.repa_encoder_name}")
621
+ x = (x - mean) / std
622
+ if self.repa_encoder.image_size != self.enc_img_size:
623
+ x = torch.nn.functional.interpolate(x, self.repa_encoder.image_size, mode='bicubic')
624
+ if "dinov2" in self.repa_encoder_name:
625
+ x = self.repa_encoder.forward_features(x)['x_norm_patchtokens']
626
+ else:
627
+ x = self.repa_encoder(x)["last_hidden_state"]
628
+ return x
629
+
630
+ def encode_slots(self, x):
631
+ if self.pretrained_encoder is not None:
632
+ x = F.interpolate(x, size=224, mode='bicubic')
633
+ slots = self.encoder(x, is_causal=self.enc_causal)
634
+ slots = self.encoder2slot(slots)
635
+ if self.norm_slots:
636
+ if not self.use_kl_loss:
637
+ slots_std = torch.std(slots, dim=-1, keepdim=True)
638
+ slots_mean = torch.mean(slots, dim=-1, keepdim=True)
639
+ slots = (slots - slots_mean) / slots_std # this works better than kl loss
640
+ else:
641
+ slots = DiagonalGaussianDistribution(slots)
642
+ return slots
643
+
644
+ def forward_with_latents(self,
645
+ x_vae,
646
+ slots,
647
+ z,
648
+ targets=None,
649
+ sample=False,
650
+ epoch=None,
651
+ inference_with_n_slots=-1,
652
+ cfg=1.0):
653
+ losses = {}
654
+ batch_size = x_vae.shape[0]
655
+ num_patches = self.dit_num_patches
656
+ device = x_vae.device
657
+
658
+ if (
659
+ epoch is not None
660
+ and epoch >= self.enable_nest_after
661
+ and self.enable_nest_after != -1
662
+ ):
663
+ self.enable_nest = True
664
+
665
+ t = torch.randint(0, 1000, (x_vae.shape[0],), device=device)
666
+
667
+ if self.enable_nest or inference_with_n_slots != -1:
668
+ if self.use_coupled_sampling:
669
+ num_slots2use = self.coupled_sampling(t)
670
+ else:
671
+ num_slots2use = None
672
+ drop_mask = self.nested_sampler(
673
+ batch_size, num_patches, device,
674
+ inference_with_n_slots=inference_with_n_slots,
675
+ coupled_value=num_slots2use,
676
+ epoch=epoch
677
+ )
678
+ else:
679
+ drop_mask = None
680
+
681
+ if sample:
682
+ return self.sample(slots if not self.use_kl_loss else slots.sample(), drop_mask=drop_mask, targets=targets, cfg=cfg)
683
+
684
+ model_kwargs = dict(autoenc_cond=slots if not self.use_kl_loss else slots.sample(), drop_mask=drop_mask, y=targets)
685
+ if not self.use_sit:
686
+ loss_dict = self.diffusion.training_losses(self.dit, x_vae, t, model_kwargs)
687
+ else:
688
+ loss_dict = self.transport.training_losses(self.dit, x_vae, model_kwargs)
689
+ diff_loss = loss_dict["loss"].mean()
690
+ losses["diff_loss"] = diff_loss
691
+
692
+ if self.use_kl_loss:
693
+ kl_loss = slots.kl()
694
+ losses["kl_loss"] = kl_loss.mean() * self.kl_loss_weight
695
+
696
+ if self.use_repa:
697
+ assert self.dit._repa_hook is not None and z is not None
698
+ z_tilde = self.dit._repa_hook
699
+
700
+ if z_tilde.shape[1] != z.shape[1]:
701
+ z_tilde = interpolate_features(z_tilde, z.shape[1])
702
+
703
+ assert z_tilde.shape[-1] == z.shape[-1], f"Feature dimensions don't match: {z_tilde.shape} vs {z.shape}"
704
+
705
+ z_tilde = F.normalize(z_tilde, dim=-1)
706
+ z = F.normalize(z, dim=-1)
707
+ repa_loss = -torch.sum(z_tilde * z, dim=-1)
708
+ losses["repa_loss"] = repa_loss.mean() * self.repa_loss_weight
709
+
710
+ return losses
711
+
712
+
713
+ def forward(self,
714
+ x,
715
+ targets=None,
716
+ latents=None,
717
+ sample=False,
718
+ epoch=None,
719
+ inference_with_n_slots=-1,
720
+ cfg=1.0):
721
+
722
+ # record the current epoch; train() uses it to decide whether to keep the encoder in eval mode
723
+ if epoch is not None:
724
+ self.current_epoch = epoch
725
+
726
+ if latents is None:
727
+ x_vae = self.vae_encode(x) # (N, C, H, W)
728
+ else:
729
+ x_vae = latents
730
+
731
+ if self.use_repa:
732
+ z = self.repa_encode(x)
733
+ else:
734
+ z = None
735
+
736
+ slots = self.encode_slots(x)
737
+ return self.forward_with_latents(x_vae, slots, z, targets, sample, epoch, inference_with_n_slots, cfg)
738
+
739
+
740
+ @torch.no_grad()
741
+ def sample(self, slots, drop_mask=None, targets=None, cfg=1.0):
742
+ batch_size = slots.shape[0]
743
+ device = slots.device
744
+ z = torch.randn(batch_size, self.dit_in_channels, self.dit_input_size, self.dit_input_size, device=device)
745
+ if cfg != 1.0:
746
+ z = torch.cat([z, z], 0)
747
+ null_slots = self.dit.null_cond.expand(batch_size, -1, -1)
748
+ slots = torch.cat([slots, null_slots], 0)
749
+ if drop_mask is not None:
750
+ null_cond_mask = torch.ones_like(drop_mask)
751
+ drop_mask = torch.cat([drop_mask, null_cond_mask], 0)
752
+ if targets is not None:
753
+ targets = torch.cat([targets, targets], 0)
754
+ model_kwargs = dict(autoenc_cond=slots, drop_mask=drop_mask, y=targets, cfg_scale=cfg)
755
+ sample_fn = self.dit.forward_with_cfg
756
+ else:
757
+ model_kwargs = dict(autoenc_cond=slots, drop_mask=drop_mask, y=targets)
758
+ sample_fn = self.dit.forward
759
+ # Sample images:
760
+ if not self.use_sit:
761
+ samples = self.gen_diffusion.p_sample_loop(
762
+ sample_fn,
763
+ z.shape,
764
+ z,
765
+ clip_denoised=False,
766
+ model_kwargs=model_kwargs,
767
+ progress=False,
768
+ device=device,
769
+ )
770
+ else:
771
+ sde_sample_fn = self.sampler.sample_sde(diffusion_form="sigma")
772
+ samples = sde_sample_fn(z, sample_fn, **model_kwargs)[-1]
773
+ if cfg != 1.0:
774
+ samples, _ = samples.chunk(2, dim=0) # Remove null class samples
775
+ samples = self.vae_decode(samples)
776
+ return samples
777
+
778
+ def cancel_gradients_encoder(self, epoch):
779
+ """Cancel gradients for encoder components after backward pass"""
780
+ if (epoch is not None
781
+ and epoch >= self.freeze_vit_after
782
+ and self.freeze_vit_after != -1):
783
+ # Directly access parameters from the modules
784
+ for p in self.encoder.parameters():
785
+ if p.grad is not None:
786
+ p.grad = None
787
+ for p in self.encoder2slot.parameters():
788
+ if p.grad is not None:
789
+ p.grad = None
790
+
791
+ def train(self, mode=True):
792
+ """Override train() to keep certain components in eval mode"""
793
+ super().train(mode)
794
+ # VAE should always be in eval mode
795
+ self.vae.eval()
796
+
797
+ # Keep encoder in eval mode if frozen
798
+ if (self.freeze_vit_after != -1 and
799
+ hasattr(self, 'current_epoch') and
800
+ self.current_epoch >= self.freeze_vit_after):
801
+ self.encoder.eval()
802
+ self.encoder2slot.eval()
803
+
804
+ # Keep DiT in eval mode if frozen
805
+ if self.freeze_dit:
806
+ self.dit.eval()
807
+
808
+ return self
paintmind/stage1/diffusion/__init__.py ADDED
@@ -0,0 +1,46 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ from . import gaussian_diffusion as gd
7
+ from .respace import SpacedDiffusion, space_timesteps
8
+
9
+
10
+ def create_diffusion(
11
+ timestep_respacing,
12
+ noise_schedule="linear",
13
+ use_kl=False,
14
+ sigma_small=False,
15
+ predict_xstart=False,
16
+ learn_sigma=True,
17
+ rescale_learned_sigmas=False,
18
+ diffusion_steps=1000
19
+ ):
20
+ betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
21
+ if use_kl:
22
+ loss_type = gd.LossType.RESCALED_KL
23
+ elif rescale_learned_sigmas:
24
+ loss_type = gd.LossType.RESCALED_MSE
25
+ else:
26
+ loss_type = gd.LossType.MSE
27
+ if timestep_respacing is None or timestep_respacing == "":
28
+ timestep_respacing = [diffusion_steps]
29
+ return SpacedDiffusion(
30
+ use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
31
+ betas=betas,
32
+ model_mean_type=(
33
+ gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
34
+ ),
35
+ model_var_type=(
36
+ (
37
+ gd.ModelVarType.FIXED_LARGE
38
+ if not sigma_small
39
+ else gd.ModelVarType.FIXED_SMALL
40
+ )
41
+ if not learn_sigma
42
+ else gd.ModelVarType.LEARNED_RANGE
43
+ ),
44
+ loss_type=loss_type
45
+ # rescale_timesteps=rescale_timesteps,
46
+ )
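A usage sketch mirroring how DiffuseSlot calls this factory: a full 1000-step schedule for training losses and a respaced schedule for sampling.

train_diffusion = create_diffusion(timestep_respacing="")       # full schedule, used for training_losses
gen_diffusion = create_diffusion(timestep_respacing="ddim25")   # 25-step DDIM respacing, used for p_sample_loop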
paintmind/stage1/diffusion/diffusion_utils.py ADDED
@@ -0,0 +1,88 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ import torch as th
7
+ import numpy as np
8
+
9
+
10
+ def normal_kl(mean1, logvar1, mean2, logvar2):
11
+ """
12
+ Compute the KL divergence between two gaussians.
13
+ Shapes are automatically broadcasted, so batches can be compared to
14
+ scalars, among other use cases.
15
+ """
16
+ tensor = None
17
+ for obj in (mean1, logvar1, mean2, logvar2):
18
+ if isinstance(obj, th.Tensor):
19
+ tensor = obj
20
+ break
21
+ assert tensor is not None, "at least one argument must be a Tensor"
22
+
23
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
24
+ # Tensors, but it does not work for th.exp().
25
+ logvar1, logvar2 = [
26
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
27
+ for x in (logvar1, logvar2)
28
+ ]
29
+
30
+ return 0.5 * (
31
+ -1.0
32
+ + logvar2
33
+ - logvar1
34
+ + th.exp(logvar1 - logvar2)
35
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
36
+ )
37
+
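For reference, the closed form implemented above (the code works in log-variance, logvar = log sigma^2):

\[ D_{\mathrm{KL}}\big(\mathcal{N}(\mu_1,\sigma_1^2)\,\|\,\mathcal{N}(\mu_2,\sigma_2^2)\big) = \log\frac{\sigma_2}{\sigma_1} + \frac{\sigma_1^2 + (\mu_1-\mu_2)^2}{2\sigma_2^2} - \frac{1}{2} \]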
38
+
39
+ def approx_standard_normal_cdf(x):
40
+ """
41
+ A fast approximation of the cumulative distribution function of the
42
+ standard normal.
43
+ """
44
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
45
+
46
+
47
+ def continuous_gaussian_log_likelihood(x, *, means, log_scales):
48
+ """
49
+ Compute the log-likelihood of a continuous Gaussian distribution.
50
+ :param x: the targets
51
+ :param means: the Gaussian mean Tensor.
52
+ :param log_scales: the Gaussian log stddev Tensor.
53
+ :return: a tensor like x of log probabilities (in nats).
54
+ """
55
+ centered_x = x - means
56
+ inv_stdv = th.exp(-log_scales)
57
+ normalized_x = centered_x * inv_stdv
58
+ log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
59
+ return log_probs
60
+
61
+
62
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
63
+ """
64
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
65
+ given image.
66
+ :param x: the target images. It is assumed that this was uint8 values,
67
+ rescaled to the range [-1, 1].
68
+ :param means: the Gaussian mean Tensor.
69
+ :param log_scales: the Gaussian log stddev Tensor.
70
+ :return: a tensor like x of log probabilities (in nats).
71
+ """
72
+ assert x.shape == means.shape == log_scales.shape
73
+ centered_x = x - means
74
+ inv_stdv = th.exp(-log_scales)
75
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
76
+ cdf_plus = approx_standard_normal_cdf(plus_in)
77
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
78
+ cdf_min = approx_standard_normal_cdf(min_in)
79
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
80
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
81
+ cdf_delta = cdf_plus - cdf_min
82
+ log_probs = th.where(
83
+ x < -0.999,
84
+ log_cdf_plus,
85
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
86
+ )
87
+ assert log_probs.shape == x.shape
88
+ return log_probs
paintmind/stage1/diffusion/gaussian_diffusion.py ADDED
@@ -0,0 +1,886 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+
7
+ import math
8
+
9
+ import numpy as np
10
+ import torch as th
11
+ import enum
12
+
13
+ from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
14
+
15
+
16
+ def mean_flat(tensor):
17
+ """
18
+ Take the mean over all non-batch dimensions.
19
+ """
20
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
21
+
22
+
23
+ class ModelMeanType(enum.Enum):
24
+ """
25
+ Which type of output the model predicts.
26
+ """
27
+
28
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
29
+ START_X = enum.auto() # the model predicts x_0
30
+ EPSILON = enum.auto() # the model predicts epsilon
31
+
32
+
33
+ class ModelVarType(enum.Enum):
34
+ """
35
+ What is used as the model's output variance.
36
+ The LEARNED_RANGE option has been added to allow the model to predict
37
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
38
+ """
39
+
40
+ LEARNED = enum.auto()
41
+ FIXED_SMALL = enum.auto()
42
+ FIXED_LARGE = enum.auto()
43
+ LEARNED_RANGE = enum.auto()
44
+
45
+
46
+ class LossType(enum.Enum):
47
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
48
+ RESCALED_MSE = (
49
+ enum.auto()
50
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
51
+ KL = enum.auto() # use the variational lower-bound
52
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
53
+
54
+ def is_vb(self):
55
+ return self == LossType.KL or self == LossType.RESCALED_KL
56
+
57
+
58
+ def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
59
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
60
+ warmup_time = int(num_diffusion_timesteps * warmup_frac)
61
+ betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
62
+ return betas
63
+
64
+
65
+ def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
66
+ """
67
+ This is the deprecated API for creating beta schedules.
68
+ See get_named_beta_schedule() for the new library of schedules.
69
+ """
70
+ if beta_schedule == "quad":
71
+ betas = (
72
+ np.linspace(
73
+ beta_start ** 0.5,
74
+ beta_end ** 0.5,
75
+ num_diffusion_timesteps,
76
+ dtype=np.float64,
77
+ )
78
+ ** 2
79
+ )
80
+ elif beta_schedule == "linear":
81
+ betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
82
+ elif beta_schedule == "warmup10":
83
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
84
+ elif beta_schedule == "warmup50":
85
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
86
+ elif beta_schedule == "const":
87
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
88
+ elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
89
+ betas = 1.0 / np.linspace(
90
+ num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
91
+ )
92
+ else:
93
+ raise NotImplementedError(beta_schedule)
94
+ assert betas.shape == (num_diffusion_timesteps,)
95
+ return betas
96
+
97
+
98
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
99
+ """
100
+ Get a pre-defined beta schedule for the given name.
101
+ The beta schedule library consists of beta schedules which remain similar
102
+ in the limit of num_diffusion_timesteps.
103
+ Beta schedules may be added, but should not be removed or changed once
104
+ they are committed to maintain backwards compatibility.
105
+ """
106
+ if schedule_name == "linear":
107
+ # Linear schedule from Ho et al, extended to work for any number of
108
+ # diffusion steps.
109
+ scale = 1000 / num_diffusion_timesteps
110
+ return get_beta_schedule(
111
+ "linear",
112
+ beta_start=scale * 0.0001,
113
+ beta_end=scale * 0.02,
114
+ num_diffusion_timesteps=num_diffusion_timesteps,
115
+ )
116
+ elif schedule_name == "cosine":
117
+ return betas_for_alpha_bar(
118
+ num_diffusion_timesteps,
119
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
120
+ )
121
+ else:
122
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
123
+
124
+
125
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
126
+ """
127
+ Create a beta schedule that discretizes the given alpha_t_bar function,
128
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
129
+ :param num_diffusion_timesteps: the number of betas to produce.
130
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
131
+ produces the cumulative product of (1-beta) up to that
132
+ part of the diffusion process.
133
+ :param max_beta: the maximum beta to use; use values lower than 1 to
134
+ prevent singularities.
135
+ """
136
+ betas = []
137
+ for i in range(num_diffusion_timesteps):
138
+ t1 = i / num_diffusion_timesteps
139
+ t2 = (i + 1) / num_diffusion_timesteps
140
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
141
+ return np.array(betas)
142
+
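In formula form, with \(t_i = i / T\) the loop above builds

\[ \beta_i = \min\!\Big(1 - \frac{\bar{\alpha}(t_{i+1})}{\bar{\alpha}(t_i)},\ \beta_{\max}\Big), \]

and the cosine schedule uses \(\bar{\alpha}(t) = \cos^2\!\big(\tfrac{t + 0.008}{1.008}\cdot\tfrac{\pi}{2}\big)\).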
143
+
144
+ class GaussianDiffusion:
145
+ """
146
+ Utilities for training and sampling diffusion models.
147
+ Original ported from this codebase:
148
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
149
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
150
+ starting at T and going to 1.
151
+ """
152
+
153
+ def __init__(
154
+ self,
155
+ *,
156
+ betas,
157
+ model_mean_type,
158
+ model_var_type,
159
+ loss_type
160
+ ):
161
+
162
+ self.model_mean_type = model_mean_type
163
+ self.model_var_type = model_var_type
164
+ self.loss_type = loss_type
165
+
166
+ # Use float64 for accuracy.
167
+ betas = np.array(betas, dtype=np.float64)
168
+ self.betas = betas
169
+ assert len(betas.shape) == 1, "betas must be 1-D"
170
+ assert (betas > 0).all() and (betas <= 1).all()
171
+
172
+ self.num_timesteps = int(betas.shape[0])
173
+
174
+ alphas = 1.0 - betas
175
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
176
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
177
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
178
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
179
+
180
+ # calculations for diffusion q(x_t | x_{t-1}) and others
181
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
182
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
183
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
184
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
185
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
186
+
187
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
188
+ self.posterior_variance = (
189
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
190
+ )
191
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
192
+ self.posterior_log_variance_clipped = np.log(
193
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
194
+ ) if len(self.posterior_variance) > 1 else np.array([])
195
+
196
+ self.posterior_mean_coef1 = (
197
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
198
+ )
199
+ self.posterior_mean_coef2 = (
200
+ (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
201
+ )
202
+
203
+ def q_mean_variance(self, x_start, t):
204
+ """
205
+ Get the distribution q(x_t | x_0).
206
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
207
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
208
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
209
+ """
210
+ mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
211
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
212
+ log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
213
+ return mean, variance, log_variance
214
+
215
+ def q_sample(self, x_start, t, noise=None):
216
+ """
217
+ Diffuse the data for a given number of diffusion steps.
218
+ In other words, sample from q(x_t | x_0).
219
+ :param x_start: the initial data batch.
220
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
221
+ :param noise: if specified, the split-out normal noise.
222
+ :return: A noisy version of x_start.
223
+ """
224
+ if noise is None:
225
+ noise = th.randn_like(x_start)
226
+ assert noise.shape == x_start.shape
227
+ return (
228
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
229
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
230
+ )
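The reparameterized sample above corresponds to the standard forward process

\[ q(x_t \mid x_0) = \mathcal{N}\big(\sqrt{\bar{\alpha}_t}\,x_0,\ (1-\bar{\alpha}_t)\mathbf{I}\big), \qquad x_t = \sqrt{\bar{\alpha}_t}\,x_0 + \sqrt{1-\bar{\alpha}_t}\,\epsilon. \]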
231
+
232
+ def q_posterior_mean_variance(self, x_start, x_t, t):
233
+ """
234
+ Compute the mean and variance of the diffusion posterior:
235
+ q(x_{t-1} | x_t, x_0)
236
+ """
237
+ assert x_start.shape == x_t.shape
238
+ posterior_mean = (
239
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
240
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
241
+ )
242
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
243
+ posterior_log_variance_clipped = _extract_into_tensor(
244
+ self.posterior_log_variance_clipped, t, x_t.shape
245
+ )
246
+ assert (
247
+ posterior_mean.shape[0]
248
+ == posterior_variance.shape[0]
249
+ == posterior_log_variance_clipped.shape[0]
250
+ == x_start.shape[0]
251
+ )
252
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
253
+
254
+ def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
255
+ """
256
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
257
+ the initial x, x_0.
258
+ :param model: the model, which takes a signal and a batch of timesteps
259
+ as input.
260
+ :param x: the [N x C x ...] tensor at time t.
261
+ :param t: a 1-D Tensor of timesteps.
262
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
263
+ :param denoised_fn: if not None, a function which applies to the
264
+ x_start prediction before it is used to sample. Applies before
265
+ clip_denoised.
266
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
267
+ pass to the model. This can be used for conditioning.
268
+ :return: a dict with the following keys:
269
+ - 'mean': the model mean output.
270
+ - 'variance': the model variance output.
271
+ - 'log_variance': the log of 'variance'.
272
+ - 'pred_xstart': the prediction for x_0.
273
+ """
274
+ if model_kwargs is None:
275
+ model_kwargs = {}
276
+
277
+ B, C = x.shape[:2]
278
+ assert t.shape == (B,)
279
+ model_output = model(x, t, **model_kwargs)
280
+ if isinstance(model_output, tuple):
281
+ model_output, extra = model_output
282
+ else:
283
+ extra = None
284
+
285
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
286
+ assert model_output.shape == (B, C * 2, *x.shape[2:])
287
+ model_output, model_var_values = th.split(model_output, C, dim=1)
288
+ min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
289
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
290
+ # The model_var_values is [-1, 1] for [min_var, max_var].
291
+ frac = (model_var_values + 1) / 2
292
+ model_log_variance = frac * max_log + (1 - frac) * min_log
293
+ model_variance = th.exp(model_log_variance)
294
+ else:
295
+ model_variance, model_log_variance = {
296
+ # for fixedlarge, we set the initial (log-)variance like so
297
+ # to get a better decoder log likelihood.
298
+ ModelVarType.FIXED_LARGE: (
299
+ np.append(self.posterior_variance[1], self.betas[1:]),
300
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
301
+ ),
302
+ ModelVarType.FIXED_SMALL: (
303
+ self.posterior_variance,
304
+ self.posterior_log_variance_clipped,
305
+ ),
306
+ }[self.model_var_type]
307
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
308
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
309
+
310
+ def process_xstart(x):
311
+ if denoised_fn is not None:
312
+ x = denoised_fn(x)
313
+ if clip_denoised:
314
+ return x.clamp(-1, 1)
315
+ return x
316
+
317
+ if self.model_mean_type == ModelMeanType.START_X:
318
+ pred_xstart = process_xstart(model_output)
319
+ else:
320
+ pred_xstart = process_xstart(
321
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
322
+ )
323
+ model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
324
+
325
+ assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
326
+ return {
327
+ "mean": model_mean,
328
+ "variance": model_variance,
329
+ "log_variance": model_log_variance,
330
+ "pred_xstart": pred_xstart,
331
+ "extra": extra,
332
+ }
333
+
334
+ def _predict_xstart_from_eps(self, x_t, t, eps):
335
+ assert x_t.shape == eps.shape
336
+ return (
337
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
338
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
339
+ )
340
+
341
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
342
+ return (
343
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
344
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
345
+
346
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
347
+ """
348
+ Compute the mean for the previous step, given a function cond_fn that
349
+ computes the gradient of a conditional log probability with respect to
350
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
351
+ condition on y.
352
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
353
+ """
354
+ gradient = cond_fn(x, t, **model_kwargs)
355
+ new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
356
+ return new_mean
357
+
358
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
359
+ """
360
+ Compute what the p_mean_variance output would have been, should the
361
+ model's score function be conditioned by cond_fn.
362
+ See condition_mean() for details on cond_fn.
363
+ Unlike condition_mean(), this instead uses the conditioning strategy
364
+ from Song et al (2020).
365
+ """
366
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
367
+
368
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
369
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
370
+
371
+ out = p_mean_var.copy()
372
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
373
+ out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
374
+ return out
375
+
376
+ def p_sample(
377
+ self,
378
+ model,
379
+ x,
380
+ t,
381
+ clip_denoised=True,
382
+ denoised_fn=None,
383
+ cond_fn=None,
384
+ model_kwargs=None,
385
+ temperature=1.0
386
+ ):
387
+ """
388
+ Sample x_{t-1} from the model at the given timestep.
389
+ :param model: the model to sample from.
390
+ :param x: the current tensor x_t to denoise from.
391
+ :param t: the value of t, starting at 0 for the first diffusion step.
392
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
393
+ :param denoised_fn: if not None, a function which applies to the
394
+ x_start prediction before it is used to sample.
395
+ :param cond_fn: if not None, this is a gradient function that acts
396
+ similarly to the model.
397
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
398
+ pass to the model. This can be used for conditioning.
399
+ :param temperature: temperature scaling during Diff Loss sampling.
400
+ :return: a dict containing the following keys:
401
+ - 'sample': a random sample from the model.
402
+ - 'pred_xstart': a prediction of x_0.
403
+ """
404
+ out = self.p_mean_variance(
405
+ model,
406
+ x,
407
+ t,
408
+ clip_denoised=clip_denoised,
409
+ denoised_fn=denoised_fn,
410
+ model_kwargs=model_kwargs,
411
+ )
412
+ noise = th.randn_like(x)
413
+ nonzero_mask = (
414
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
415
+ ) # no noise when t == 0
416
+ if cond_fn is not None:
417
+ out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
418
+ # scale the noise by temperature
419
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise * temperature
420
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
421
+
422
+ def p_sample_loop(
423
+ self,
424
+ model,
425
+ shape,
426
+ noise=None,
427
+ clip_denoised=True,
428
+ denoised_fn=None,
429
+ cond_fn=None,
430
+ model_kwargs=None,
431
+ device=None,
432
+ progress=False,
433
+ temperature=1.0,
434
+ ):
435
+ """
436
+ Generate samples from the model.
437
+ :param model: the model module.
438
+ :param shape: the shape of the samples, (N, C, H, W).
439
+ :param noise: if specified, the noise from the encoder to sample.
440
+ Should be of the same shape as `shape`.
441
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
442
+ :param denoised_fn: if not None, a function which applies to the
443
+ x_start prediction before it is used to sample.
444
+ :param cond_fn: if not None, this is a gradient function that acts
445
+ similarly to the model.
446
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
447
+ pass to the model. This can be used for conditioning.
448
+ :param device: if specified, the device to create the samples on.
449
+ If not specified, use a model parameter's device.
450
+ :param progress: if True, show a tqdm progress bar.
451
+ :param temperature: temperature scaling during Diff Loss sampling.
452
+ :return: a non-differentiable batch of samples.
453
+ """
454
+ final = None
455
+ for sample in self.p_sample_loop_progressive(
456
+ model,
457
+ shape,
458
+ noise=noise,
459
+ clip_denoised=clip_denoised,
460
+ denoised_fn=denoised_fn,
461
+ cond_fn=cond_fn,
462
+ model_kwargs=model_kwargs,
463
+ device=device,
464
+ progress=progress,
465
+ temperature=temperature,
466
+ ):
467
+ final = sample
468
+ return final["sample"]
469
+
470
+ def p_sample_loop_progressive(
471
+ self,
472
+ model,
473
+ shape,
474
+ noise=None,
475
+ clip_denoised=True,
476
+ denoised_fn=None,
477
+ cond_fn=None,
478
+ model_kwargs=None,
479
+ device=None,
480
+ progress=False,
481
+ temperature=1.0,
482
+ ):
483
+ """
484
+ Generate samples from the model and yield intermediate samples from
485
+ each timestep of diffusion.
486
+ Arguments are the same as p_sample_loop().
487
+ Returns a generator over dicts, where each dict is the return value of
488
+ p_sample().
489
+ """
490
+ if device is None:
491
+ device = next(model.parameters()).device
492
+ assert isinstance(shape, (tuple, list))
493
+ if noise is not None:
494
+ img = noise
495
+ else:
496
+ img = th.randn(*shape, device=device)
497
+ indices = list(range(self.num_timesteps))[::-1]
498
+
499
+ if progress:
500
+ # Lazy import so that we don't depend on tqdm.
501
+ from tqdm.auto import tqdm
502
+
503
+ indices = tqdm(indices)
504
+
505
+ for i in indices:
506
+ t = th.tensor([i] * shape[0], device=device)
507
+ with th.no_grad():
508
+ out = self.p_sample(
509
+ model,
510
+ img,
511
+ t,
512
+ clip_denoised=clip_denoised,
513
+ denoised_fn=denoised_fn,
514
+ cond_fn=cond_fn,
515
+ model_kwargs=model_kwargs,
516
+ temperature=temperature,
517
+ )
518
+ yield out
519
+ img = out["sample"]
520
+
521
+ def ddim_sample(
522
+ self,
523
+ model,
524
+ x,
525
+ t,
526
+ clip_denoised=True,
527
+ denoised_fn=None,
528
+ cond_fn=None,
529
+ model_kwargs=None,
530
+ eta=0.0,
531
+ ):
532
+ """
533
+ Sample x_{t-1} from the model using DDIM.
534
+ Same usage as p_sample().
535
+ """
536
+ out = self.p_mean_variance(
537
+ model,
538
+ x,
539
+ t,
540
+ clip_denoised=clip_denoised,
541
+ denoised_fn=denoised_fn,
542
+ model_kwargs=model_kwargs,
543
+ )
544
+ if cond_fn is not None:
545
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
546
+
547
+ # Usually our model outputs epsilon, but we re-derive it
548
+ # in case we used x_start or x_prev prediction.
549
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
550
+
551
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
552
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
553
+ sigma = (
554
+ eta
555
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
556
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
557
+ )
558
+ # Equation 12.
559
+ noise = th.randn_like(x)
560
+ mean_pred = (
561
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
562
+ + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
563
+ )
564
+ nonzero_mask = (
565
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
566
+ ) # no noise when t == 0
567
+ sample = mean_pred + nonzero_mask * sigma * noise
568
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
569
+
570
+ def ddim_reverse_sample(
571
+ self,
572
+ model,
573
+ x,
574
+ t,
575
+ clip_denoised=True,
576
+ denoised_fn=None,
577
+ cond_fn=None,
578
+ model_kwargs=None,
579
+ eta=0.0,
580
+ ):
581
+ """
582
+ Sample x_{t+1} from the model using DDIM reverse ODE.
583
+ """
584
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
585
+ out = self.p_mean_variance(
586
+ model,
587
+ x,
588
+ t,
589
+ clip_denoised=clip_denoised,
590
+ denoised_fn=denoised_fn,
591
+ model_kwargs=model_kwargs,
592
+ )
593
+ if cond_fn is not None:
594
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
595
+ # Usually our model outputs epsilon, but we re-derive it
596
+ # in case we used x_start or x_prev prediction.
597
+ eps = (
598
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
599
+ - out["pred_xstart"]
600
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
601
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
602
+
603
+ # Equation 12. reversed
604
+ mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
605
+
606
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
607
+
608
+ def ddim_sample_loop(
609
+ self,
610
+ model,
611
+ shape,
612
+ noise=None,
613
+ clip_denoised=True,
614
+ denoised_fn=None,
615
+ cond_fn=None,
616
+ model_kwargs=None,
617
+ device=None,
618
+ progress=False,
619
+ eta=0.0,
620
+ ):
621
+ """
622
+ Generate samples from the model using DDIM.
623
+ Same usage as p_sample_loop().
624
+ """
625
+ final = None
626
+ for sample in self.ddim_sample_loop_progressive(
627
+ model,
628
+ shape,
629
+ noise=noise,
630
+ clip_denoised=clip_denoised,
631
+ denoised_fn=denoised_fn,
632
+ cond_fn=cond_fn,
633
+ model_kwargs=model_kwargs,
634
+ device=device,
635
+ progress=progress,
636
+ eta=eta,
637
+ ):
638
+ final = sample
639
+ return final["sample"]
640
+
641
+ def ddim_sample_loop_progressive(
642
+ self,
643
+ model,
644
+ shape,
645
+ noise=None,
646
+ clip_denoised=True,
647
+ denoised_fn=None,
648
+ cond_fn=None,
649
+ model_kwargs=None,
650
+ device=None,
651
+ progress=False,
652
+ eta=0.0,
653
+ ):
654
+ """
655
+ Use DDIM to sample from the model and yield intermediate samples from
656
+ each timestep of DDIM.
657
+ Same usage as p_sample_loop_progressive().
658
+ """
659
+ if device is None:
660
+ device = next(model.parameters()).device
661
+ assert isinstance(shape, (tuple, list))
662
+ if noise is not None:
663
+ img = noise
664
+ else:
665
+ img = th.randn(*shape, device=device)
666
+ indices = list(range(self.num_timesteps))[::-1]
667
+
668
+ if progress:
669
+ # Lazy import so that we don't depend on tqdm.
670
+ from tqdm.auto import tqdm
671
+
672
+ indices = tqdm(indices)
673
+
674
+ for i in indices:
675
+ t = th.tensor([i] * shape[0], device=device)
676
+ with th.no_grad():
677
+ out = self.ddim_sample(
678
+ model,
679
+ img,
680
+ t,
681
+ clip_denoised=clip_denoised,
682
+ denoised_fn=denoised_fn,
683
+ cond_fn=cond_fn,
684
+ model_kwargs=model_kwargs,
685
+ eta=eta,
686
+ )
687
+ yield out
688
+ img = out["sample"]
689
+
690
+ def _vb_terms_bpd(
691
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
692
+ ):
693
+ """
694
+ Get a term for the variational lower-bound.
695
+ The resulting units are bits (rather than nats, as one might expect).
696
+ This allows for comparison to other papers.
697
+ :return: a dict with the following keys:
698
+ - 'output': a shape [N] tensor of NLLs or KLs.
699
+ - 'pred_xstart': the x_0 predictions.
700
+ """
701
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
702
+ x_start=x_start, x_t=x_t, t=t
703
+ )
704
+ out = self.p_mean_variance(
705
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
706
+ )
707
+ kl = normal_kl(
708
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
709
+ )
710
+ kl = mean_flat(kl) / np.log(2.0)
711
+
712
+ decoder_nll = -discretized_gaussian_log_likelihood(
713
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
714
+ )
715
+ assert decoder_nll.shape == x_start.shape
716
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
717
+
718
+ # At the first timestep return the decoder NLL,
719
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
720
+ output = th.where((t == 0), decoder_nll, kl)
721
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
722
+
723
+ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
724
+ """
725
+ Compute training losses for a single timestep.
726
+ :param model: the model to evaluate loss on.
727
+ :param x_start: the [N x C x ...] tensor of inputs.
728
+ :param t: a batch of timestep indices.
729
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
730
+ pass to the model. This can be used for conditioning.
731
+ :param noise: if specified, the specific Gaussian noise to try to remove.
732
+ :return: a dict with the key "loss" containing a tensor of shape [N].
733
+ Some mean or variance settings may also have other keys.
734
+ """
735
+ if model_kwargs is None:
736
+ model_kwargs = {}
737
+ if noise is None:
738
+ noise = th.randn_like(x_start)
739
+ x_t = self.q_sample(x_start, t, noise=noise)
740
+
741
+ terms = {}
742
+
743
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
744
+ terms["loss"] = self._vb_terms_bpd(
745
+ model=model,
746
+ x_start=x_start,
747
+ x_t=x_t,
748
+ t=t,
749
+ clip_denoised=False,
750
+ model_kwargs=model_kwargs,
751
+ )["output"]
752
+ if self.loss_type == LossType.RESCALED_KL:
753
+ terms["loss"] *= self.num_timesteps
754
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
755
+ model_output = model(x_t, t, **model_kwargs)
756
+
757
+ if self.model_var_type in [
758
+ ModelVarType.LEARNED,
759
+ ModelVarType.LEARNED_RANGE,
760
+ ]:
761
+ B, C = x_t.shape[:2]
762
+ if len(model_output.shape) == len(x_t.shape) + 1:
763
+ x_t = x_t.unsqueeze(-1).expand(*([-1] * (len(x_t.shape))), model_output.shape[-1])
764
+ x_start = x_start.unsqueeze(-1).expand(*([-1] * (len(x_start.shape))), model_output.shape[-1])
765
+ assert model_output.shape == (B, C * 2, *x_t.shape[2:])
766
+ model_output, model_var_values = th.split(model_output, C, dim=1)
767
+ # Learn the variance using the variational bound, but don't let
768
+ # it affect our mean prediction.
769
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
770
+ terms["vb"] = self._vb_terms_bpd(
771
+ model=lambda *args, r=frozen_out: r,
772
+ x_start=x_start,
773
+ x_t=x_t,
774
+ t=t,
775
+ clip_denoised=False,
776
+ )["output"]
777
+ if self.loss_type == LossType.RESCALED_MSE:
778
+ # Divide by 1000 for equivalence with initial implementation.
779
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
780
+ terms["vb"] *= self.num_timesteps / 1000.0
781
+
782
+ target = {
783
+ ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
784
+ x_start=x_start, x_t=x_t, t=t
785
+ )[0],
786
+ ModelMeanType.START_X: x_start,
787
+ ModelMeanType.EPSILON: noise,
788
+ }[self.model_mean_type]
789
+ if len(model_output.shape) == len(target.shape) + 1:
790
+ target = target.unsqueeze(-1).expand(*([-1] * (len(target.shape))), model_output.shape[-1])
791
+ assert model_output.shape == target.shape == x_start.shape
792
+ terms["mse"] = mean_flat((target - model_output) ** 2)
793
+ if "vb" in terms:
794
+ terms["loss"] = terms["mse"] + terms["vb"]
795
+ else:
796
+ terms["loss"] = terms["mse"]
797
+ else:
798
+ raise NotImplementedError(self.loss_type)
799
+
800
+ return terms
801
+
802
+ def _prior_bpd(self, x_start):
803
+ """
804
+ Get the prior KL term for the variational lower-bound, measured in
805
+ bits-per-dim.
806
+ This term can't be optimized, as it only depends on the encoder.
807
+ :param x_start: the [N x C x ...] tensor of inputs.
808
+ :return: a batch of [N] KL values (in bits), one per batch element.
809
+ """
810
+ batch_size = x_start.shape[0]
811
+ t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
812
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
813
+ kl_prior = normal_kl(
814
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
815
+ )
816
+ return mean_flat(kl_prior) / np.log(2.0)
817
+
818
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
819
+ """
820
+ Compute the entire variational lower-bound, measured in bits-per-dim,
821
+ as well as other related quantities.
822
+ :param model: the model to evaluate loss on.
823
+ :param x_start: the [N x C x ...] tensor of inputs.
824
+ :param clip_denoised: if True, clip denoised samples.
825
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
826
+ pass to the model. This can be used for conditioning.
827
+ :return: a dict containing the following keys:
828
+ - total_bpd: the total variational lower-bound, per batch element.
829
+ - prior_bpd: the prior term in the lower-bound.
830
+ - vb: an [N x T] tensor of terms in the lower-bound.
831
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
832
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
833
+ """
834
+ device = x_start.device
835
+ batch_size = x_start.shape[0]
836
+
837
+ vb = []
838
+ xstart_mse = []
839
+ mse = []
840
+ for t in list(range(self.num_timesteps))[::-1]:
841
+ t_batch = th.tensor([t] * batch_size, device=device)
842
+ noise = th.randn_like(x_start)
843
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
844
+ # Calculate VLB term at the current timestep
845
+ with th.no_grad():
846
+ out = self._vb_terms_bpd(
847
+ model,
848
+ x_start=x_start,
849
+ x_t=x_t,
850
+ t=t_batch,
851
+ clip_denoised=clip_denoised,
852
+ model_kwargs=model_kwargs,
853
+ )
854
+ vb.append(out["output"])
855
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
856
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
857
+ mse.append(mean_flat((eps - noise) ** 2))
858
+
859
+ vb = th.stack(vb, dim=1)
860
+ xstart_mse = th.stack(xstart_mse, dim=1)
861
+ mse = th.stack(mse, dim=1)
862
+
863
+ prior_bpd = self._prior_bpd(x_start)
864
+ total_bpd = vb.sum(dim=1) + prior_bpd
865
+ return {
866
+ "total_bpd": total_bpd,
867
+ "prior_bpd": prior_bpd,
868
+ "vb": vb,
869
+ "xstart_mse": xstart_mse,
870
+ "mse": mse,
871
+ }
872
+
873
+
874
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
875
+ """
876
+ Extract values from a 1-D numpy array for a batch of indices.
877
+ :param arr: the 1-D numpy array.
878
+ :param timesteps: a tensor of indices into the array to extract.
879
+ :param broadcast_shape: a larger shape of K dimensions with the batch
880
+ dimension equal to the length of timesteps.
881
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
882
+ """
883
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
884
+ while len(res.shape) < len(broadcast_shape):
885
+ res = res[..., None]
886
+ return res + th.zeros(broadcast_shape, device=timesteps.device)
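The class above exposes both ancestral sampling (p_sample_loop) and DDIM sampling (ddim_sample_loop). A minimal, hedged sketch of how they might be driven, assuming `diffusion` is an already-constructed GaussianDiffusion (or SpacedDiffusion) instance and `net` is a denoising network with signature net(x, t, **kwargs):

import torch as th

@th.no_grad()
def sample_images(diffusion, net, batch_size, latent_shape=(4, 32, 32), use_ddim=False, device="cuda"):
    shape = (batch_size, *latent_shape)
    if use_ddim:
        # deterministic DDIM sampling (eta defaults to 0.0)
        return diffusion.ddim_sample_loop(net, shape, device=device, progress=True)
    # stochastic ancestral sampling; temperature scales the injected noise
    return diffusion.p_sample_loop(net, shape, device=device, progress=True, temperature=1.0)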
paintmind/stage1/diffusion/respace.py ADDED
@@ -0,0 +1,130 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ import numpy as np
7
+ import torch as th
8
+
9
+ from .gaussian_diffusion import GaussianDiffusion
10
+
11
+
12
+ def space_timesteps(num_timesteps, section_counts):
13
+ """
14
+ Create a list of timesteps to use from an original diffusion process,
15
+ given the number of timesteps we want to take from equally-sized portions
16
+ of the original process.
17
+ For example, if there are 300 timesteps and the section counts are [10,15,20]
18
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
19
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
20
+ If the stride is a string starting with "ddim", then the fixed striding
21
+ from the DDIM paper is used, and only one section is allowed.
22
+ :param num_timesteps: the number of diffusion steps in the original
23
+ process to divide up.
24
+ :param section_counts: either a list of numbers, or a string containing
25
+ comma-separated numbers, indicating the step count
26
+ per section. As a special case, use "ddimN" where N
27
+ is a number of steps to use the striding from the
28
+ DDIM paper.
29
+ :return: a set of diffusion steps from the original process to use.
30
+ """
31
+ if isinstance(section_counts, str):
32
+ if section_counts.startswith("ddim"):
33
+ desired_count = int(section_counts[len("ddim") :])
34
+ for i in range(1, num_timesteps):
35
+ if len(range(0, num_timesteps, i)) == desired_count:
36
+ return set(range(0, num_timesteps, i))
37
+ raise ValueError(
38
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
39
+ )
40
+ section_counts = [int(x) for x in section_counts.split(",")]
41
+ size_per = num_timesteps // len(section_counts)
42
+ extra = num_timesteps % len(section_counts)
43
+ start_idx = 0
44
+ all_steps = []
45
+ for i, section_count in enumerate(section_counts):
46
+ size = size_per + (1 if i < extra else 0)
47
+ if size < section_count:
48
+ raise ValueError(
49
+ f"cannot divide section of {size} steps into {section_count}"
50
+ )
51
+ if section_count <= 1:
52
+ frac_stride = 1
53
+ else:
54
+ frac_stride = (size - 1) / (section_count - 1)
55
+ cur_idx = 0.0
56
+ taken_steps = []
57
+ for _ in range(section_count):
58
+ taken_steps.append(start_idx + round(cur_idx))
59
+ cur_idx += frac_stride
60
+ all_steps += taken_steps
61
+ start_idx += size
62
+ return set(all_steps)
63
+
64
+
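Two small worked examples of the striding above, derived directly from the code (a "ddimN" string scans for an integer stride, while a list spaces each section with a fractional stride):

from paintmind.stage1.diffusion.respace import space_timesteps

# "ddim25": the first stride i with len(range(0, 1000, i)) == 25 is i = 40
assert space_timesteps(1000, "ddim25") == set(range(0, 1000, 40))

# [10]: frac_stride = (100 - 1) / (10 - 1) = 11, giving {0, 11, 22, ..., 99}
assert space_timesteps(100, [10]) == set(range(0, 100, 11))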
65
+ class SpacedDiffusion(GaussianDiffusion):
66
+ """
67
+ A diffusion process which can skip steps in a base diffusion process.
68
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
69
+ original diffusion process to retain.
70
+ :param kwargs: the kwargs to create the base diffusion process.
71
+ """
72
+
73
+ def __init__(self, use_timesteps, **kwargs):
74
+ self.use_timesteps = set(use_timesteps)
75
+ self.timestep_map = []
76
+ self.original_num_steps = len(kwargs["betas"])
77
+
78
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
79
+ last_alpha_cumprod = 1.0
80
+ new_betas = []
81
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
82
+ if i in self.use_timesteps:
83
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
84
+ last_alpha_cumprod = alpha_cumprod
85
+ self.timestep_map.append(i)
86
+ kwargs["betas"] = np.array(new_betas)
87
+ super().__init__(**kwargs)
88
+
89
+ def p_mean_variance(
90
+ self, model, *args, **kwargs
91
+ ): # pylint: disable=signature-differs
92
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
93
+
94
+ def training_losses(
95
+ self, model, *args, **kwargs
96
+ ): # pylint: disable=signature-differs
97
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
98
+
99
+ def condition_mean(self, cond_fn, *args, **kwargs):
100
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
101
+
102
+ def condition_score(self, cond_fn, *args, **kwargs):
103
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
104
+
105
+ def _wrap_model(self, model):
106
+ if isinstance(model, _WrappedModel):
107
+ return model
108
+ return _WrappedModel(
109
+ model, self.timestep_map, self.original_num_steps
110
+ )
111
+
112
+ def _scale_timesteps(self, t):
113
+ # Scaling is done by the wrapped model.
114
+ return t
115
+
116
+
117
+ class _WrappedModel:
118
+ def __init__(self, model, timestep_map, original_num_steps):
119
+ self.model = model
120
+ self.timestep_map = timestep_map
121
+ # self.rescale_timesteps = rescale_timesteps
122
+ self.original_num_steps = original_num_steps
123
+
124
+ def __call__(self, x, ts, **kwargs):
125
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
126
+ new_ts = map_tensor[ts]
127
+ # if self.rescale_timesteps:
128
+ # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
129
+ return self.model(x, new_ts, **kwargs)
130
+
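A hedged sketch of how this wrapper is typically built: the full schedule is thinned with space_timesteps and the surviving betas are recomputed in __init__. The helper get_named_beta_schedule and the GaussianDiffusion constructor keywords are assumptions carried over from the upstream ADM code; the enum values appear in gaussian_diffusion.py above.

from paintmind.stage1.diffusion.gaussian_diffusion import (
    get_named_beta_schedule, ModelMeanType, ModelVarType, LossType,  # assumed helpers/enums
)
from paintmind.stage1.diffusion.respace import SpacedDiffusion, space_timesteps

spaced = SpacedDiffusion(
    use_timesteps=space_timesteps(1000, "50"),        # keep 50 evenly spaced steps
    betas=get_named_beta_schedule("linear", 1000),    # assumed schedule helper
    model_mean_type=ModelMeanType.EPSILON,
    model_var_type=ModelVarType.LEARNED_RANGE,
    loss_type=LossType.MSE,
)
# Sampling now costs 50 network evaluations; _WrappedModel remaps the spaced
# indices back onto the original 1000-step schedule.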
paintmind/stage1/diffusion/timestep_sampler.py ADDED
@@ -0,0 +1,150 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ from abc import ABC, abstractmethod
7
+
8
+ import numpy as np
9
+ import torch as th
10
+ import torch.distributed as dist
11
+
12
+
13
+ def create_named_schedule_sampler(name, diffusion):
14
+ """
15
+ Create a ScheduleSampler from a library of pre-defined samplers.
16
+ :param name: the name of the sampler.
17
+ :param diffusion: the diffusion object to sample for.
18
+ """
19
+ if name == "uniform":
20
+ return UniformSampler(diffusion)
21
+ elif name == "loss-second-moment":
22
+ return LossSecondMomentResampler(diffusion)
23
+ else:
24
+ raise NotImplementedError(f"unknown schedule sampler: {name}")
25
+
26
+
27
+ class ScheduleSampler(ABC):
28
+ """
29
+ A distribution over timesteps in the diffusion process, intended to reduce
30
+ variance of the objective.
31
+ By default, samplers perform unbiased importance sampling, in which the
32
+ objective's mean is unchanged.
33
+ However, subclasses may override sample() to change how the resampled
34
+ terms are reweighted, allowing for actual changes in the objective.
35
+ """
36
+
37
+ @abstractmethod
38
+ def weights(self):
39
+ """
40
+ Get a numpy array of weights, one per diffusion step.
41
+ The weights needn't be normalized, but must be positive.
42
+ """
43
+
44
+ def sample(self, batch_size, device):
45
+ """
46
+ Importance-sample timesteps for a batch.
47
+ :param batch_size: the number of timesteps.
48
+ :param device: the torch device to save to.
49
+ :return: a tuple (timesteps, weights):
50
+ - timesteps: a tensor of timestep indices.
51
+ - weights: a tensor of weights to scale the resulting losses.
52
+ """
53
+ w = self.weights()
54
+ p = w / np.sum(w)
55
+ indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
56
+ indices = th.from_numpy(indices_np).long().to(device)
57
+ weights_np = 1 / (len(p) * p[indices_np])
58
+ weights = th.from_numpy(weights_np).float().to(device)
59
+ return indices, weights
60
+
61
+
62
+ class UniformSampler(ScheduleSampler):
63
+ def __init__(self, diffusion):
64
+ self.diffusion = diffusion
65
+ self._weights = np.ones([diffusion.num_timesteps])
66
+
67
+ def weights(self):
68
+ return self._weights
69
+
70
+
71
+ class LossAwareSampler(ScheduleSampler):
72
+ def update_with_local_losses(self, local_ts, local_losses):
73
+ """
74
+ Update the reweighting using losses from a model.
75
+ Call this method from each rank with a batch of timesteps and the
76
+ corresponding losses for each of those timesteps.
77
+ This method will perform synchronization to make sure all of the ranks
78
+ maintain the exact same reweighting.
79
+ :param local_ts: an integer Tensor of timesteps.
80
+ :param local_losses: a 1D Tensor of losses.
81
+ """
82
+ batch_sizes = [
83
+ th.tensor([0], dtype=th.int32, device=local_ts.device)
84
+ for _ in range(dist.get_world_size())
85
+ ]
86
+ dist.all_gather(
87
+ batch_sizes,
88
+ th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
89
+ )
90
+
91
+ # Pad all_gather batches to be the maximum batch size.
92
+ batch_sizes = [x.item() for x in batch_sizes]
93
+ max_bs = max(batch_sizes)
94
+
95
+ timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
96
+ loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
97
+ dist.all_gather(timestep_batches, local_ts)
98
+ dist.all_gather(loss_batches, local_losses)
99
+ timesteps = [
100
+ x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
101
+ ]
102
+ losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
103
+ self.update_with_all_losses(timesteps, losses)
104
+
105
+ @abstractmethod
106
+ def update_with_all_losses(self, ts, losses):
107
+ """
108
+ Update the reweighting using losses from a model.
109
+ Sub-classes should override this method to update the reweighting
110
+ using losses from the model.
111
+ This method directly updates the reweighting without synchronizing
112
+ between workers. It is called by update_with_local_losses from all
113
+ ranks with identical arguments. Thus, it should have deterministic
114
+ behavior to maintain state across workers.
115
+ :param ts: a list of int timesteps.
116
+ :param losses: a list of float losses, one per timestep.
117
+ """
118
+
119
+
120
+ class LossSecondMomentResampler(LossAwareSampler):
121
+ def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
122
+ self.diffusion = diffusion
123
+ self.history_per_term = history_per_term
124
+ self.uniform_prob = uniform_prob
125
+ self._loss_history = np.zeros(
126
+ [diffusion.num_timesteps, history_per_term], dtype=np.float64
127
+ )
128
+ self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
129
+
130
+ def weights(self):
131
+ if not self._warmed_up():
132
+ return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
133
+ weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
134
+ weights /= np.sum(weights)
135
+ weights *= 1 - self.uniform_prob
136
+ weights += self.uniform_prob / len(weights)
137
+ return weights
138
+
139
+ def update_with_all_losses(self, ts, losses):
140
+ for t, loss in zip(ts, losses):
141
+ if self._loss_counts[t] == self.history_per_term:
142
+ # Shift out the oldest loss term.
143
+ self._loss_history[t, :-1] = self._loss_history[t, 1:]
144
+ self._loss_history[t, -1] = loss
145
+ else:
146
+ self._loss_history[t, self._loss_counts[t]] = loss
147
+ self._loss_counts[t] += 1
148
+
149
+ def _warmed_up(self):
150
+ return (self._loss_counts == self.history_per_term).all()
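A hedged sketch of how these samplers plug into a training step; `diffusion` (a GaussianDiffusion/SpacedDiffusion instance) and `model` are assumed to exist, and the loss-aware update additionally assumes torch.distributed has been initialized, since it calls all_gather internally:

def diffusion_training_step(model, diffusion, sampler, x_start):
    # importance-sample timesteps plus the weights that keep the objective unbiased
    t, weights = sampler.sample(x_start.shape[0], x_start.device)
    losses = diffusion.training_losses(model, x_start, t)
    if isinstance(sampler, LossAwareSampler):
        # keep the resampling distribution in sync with the observed losses
        sampler.update_with_local_losses(t, losses["loss"].detach())
    return (weights * losses["loss"]).mean()

# sampler = create_named_schedule_sampler("loss-second-moment", diffusion)
# loss = diffusion_training_step(model, diffusion, sampler, x_start); loss.backward()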
paintmind/stage1/diffusion_transfomers.py ADDED
@@ -0,0 +1,372 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # GLIDE: https://github.com/openai/glide-text2im
9
+ # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
10
+ # --------------------------------------------------------
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import numpy as np
15
+ import math
16
+ from timm.models.vision_transformer import PatchEmbed, Mlp
17
+ from paintmind.stage1.fused_attention import Attention
18
+
19
+
20
+
21
+ def modulate(x, shift, scale):
22
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
23
+
24
+
25
+ #################################################################################
26
+ # Embedding Layers for Timesteps and Class Labels #
27
+ #################################################################################
28
+
29
+ class TimestepEmbedder(nn.Module):
30
+ """
31
+ Embeds scalar timesteps into vector representations.
32
+ """
33
+ def __init__(self, hidden_size, frequency_embedding_size=256):
34
+ super().__init__()
35
+ self.mlp = nn.Sequential(
36
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
37
+ nn.SiLU(),
38
+ nn.Linear(hidden_size, hidden_size, bias=True),
39
+ )
40
+ self.frequency_embedding_size = frequency_embedding_size
41
+
42
+ @staticmethod
43
+ def timestep_embedding(t, dim, max_period=10000):
44
+ """
45
+ Create sinusoidal timestep embeddings.
46
+ :param t: a 1-D Tensor of N indices, one per batch element.
47
+ These may be fractional.
48
+ :param dim: the dimension of the output.
49
+ :param max_period: controls the minimum frequency of the embeddings.
50
+ :return: an (N, D) Tensor of positional embeddings.
51
+ """
52
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
53
+ half = dim // 2
54
+ freqs = torch.exp(
55
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
56
+ ).to(device=t.device)
57
+ args = t[:, None].float() * freqs[None]
58
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
59
+ if dim % 2:
60
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
61
+ return embedding
62
+
63
+ def forward(self, t):
64
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
65
+ t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
66
+ return t_emb
67
+
68
+
69
+ class LabelEmbedder(nn.Module):
70
+ """
71
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
72
+ """
73
+ def __init__(self, num_classes, hidden_size, dropout_prob):
74
+ super().__init__()
75
+ use_cfg_embedding = dropout_prob > 0
76
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
77
+ self.num_classes = num_classes
78
+ self.dropout_prob = dropout_prob
79
+
80
+ def token_drop(self, labels, force_drop_ids=None):
81
+ """
82
+ Drops labels to enable classifier-free guidance.
83
+ """
84
+ if force_drop_ids is None:
85
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
86
+ else:
87
+ drop_ids = force_drop_ids == 1
88
+ labels = torch.where(drop_ids, self.num_classes, labels)
89
+ return labels
90
+
91
+ def forward(self, labels, train, force_drop_ids=None):
92
+ use_dropout = self.dropout_prob > 0
93
+ if (train and use_dropout) or (force_drop_ids is not None):
94
+ labels = self.token_drop(labels, force_drop_ids)
95
+ embeddings = self.embedding_table(labels)
96
+ return embeddings
97
+
98
+
99
+ #################################################################################
100
+ # Core DiT Model #
101
+ #################################################################################
102
+
103
+ class DiTBlock(nn.Module):
104
+ """
105
+ A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
106
+ """
107
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
108
+ super().__init__()
109
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
110
+ self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
111
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
112
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
113
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
114
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
115
+ self.adaLN_modulation = nn.Sequential(
116
+ nn.SiLU(),
117
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
118
+ )
119
+
120
+ def forward(self, x, c, mask=None):
121
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
122
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), mask)
123
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
124
+ return x
125
+
126
+
127
+ class FinalLayer(nn.Module):
128
+ """
129
+ The final layer of DiT.
130
+ """
131
+ def __init__(self, hidden_size, patch_size, out_channels):
132
+ super().__init__()
133
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
134
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
135
+ self.adaLN_modulation = nn.Sequential(
136
+ nn.SiLU(),
137
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
138
+ )
139
+
140
+ def forward(self, x, c):
141
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
142
+ x = modulate(self.norm_final(x), shift, scale)
143
+ x = self.linear(x)
144
+ return x
145
+
146
+
147
+ class DiT(nn.Module):
148
+ """
149
+ Diffusion model with a Transformer backbone.
150
+ """
151
+ def __init__(
152
+ self,
153
+ input_size=32,
154
+ patch_size=2,
155
+ in_channels=4,
156
+ hidden_size=1152,
157
+ depth=28,
158
+ num_heads=16,
159
+ mlp_ratio=4.0,
160
+ class_dropout_prob=0.1,
161
+ num_classes=1000,
162
+ learn_sigma=True,
163
+ ):
164
+ super().__init__()
165
+ self.learn_sigma = learn_sigma
166
+ self.in_channels = in_channels
167
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
168
+ self.patch_size = patch_size
169
+ self.num_heads = num_heads
170
+
171
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
172
+ self.t_embedder = TimestepEmbedder(hidden_size)
173
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
174
+ num_patches = self.x_embedder.num_patches
175
+ # Will use fixed sin-cos embedding:
176
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
177
+
178
+ self.blocks = nn.ModuleList([
179
+ DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
180
+ ])
181
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
182
+ self.initialize_weights()
183
+
184
+ def initialize_weights(self):
185
+ # Initialize transformer layers:
186
+ def _basic_init(module):
187
+ if isinstance(module, nn.Linear):
188
+ torch.nn.init.xavier_uniform_(module.weight)
189
+ if module.bias is not None:
190
+ nn.init.constant_(module.bias, 0)
191
+ self.apply(_basic_init)
192
+
193
+ # Initialize (and freeze) pos_embed by sin-cos embedding:
194
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
195
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
196
+
197
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
198
+ w = self.x_embedder.proj.weight.data
199
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
200
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
201
+
202
+ # Initialize label embedding table:
203
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
204
+
205
+ # Initialize timestep embedding MLP:
206
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
207
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
208
+
209
+ # Zero-out adaLN modulation layers in DiT blocks:
210
+ for block in self.blocks:
211
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
212
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
213
+
214
+ # Zero-out output layers:
215
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
216
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
217
+ nn.init.constant_(self.final_layer.linear.weight, 0)
218
+ nn.init.constant_(self.final_layer.linear.bias, 0)
219
+
220
+ def unpatchify(self, x):
221
+ """
222
+ x: (N, T, patch_size**2 * C)
223
+ imgs: (N, H, W, C)
224
+ """
225
+ c = self.out_channels
226
+ p = self.x_embedder.patch_size[0]
227
+ h = w = int(x.shape[1] ** 0.5)
228
+ assert h * w == x.shape[1]
229
+
230
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
231
+ x = torch.einsum('nhwpqc->nchpwq', x)
232
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
233
+ return imgs
234
+
235
+ def forward(self, x, t, y):
236
+ """
237
+ Forward pass of DiT.
238
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
239
+ t: (N,) tensor of diffusion timesteps
240
+ y: (N,) tensor of class labels
241
+ """
242
+ x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
243
+ t = self.t_embedder(t) # (N, D)
244
+ y = self.y_embedder(y, self.training) # (N, D)
245
+ c = t + y # (N, D)
246
+ for block in self.blocks:
247
+ x = block(x, c) # (N, T, D)
248
+ x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
249
+ x = self.unpatchify(x) # (N, out_channels, H, W)
250
+ return x
251
+
252
+ def forward_with_cfg(self, x, t, y, cfg_scale):
253
+ """
254
+ Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
255
+ """
256
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
257
+ half = x[: len(x) // 2]
258
+ combined = torch.cat([half, half], dim=0)
259
+ model_out = self.forward(combined, t, y)
260
+ # The original DiT applies classifier-free guidance to only the first three
261
+ # channels for exact reproducibility. Here guidance is applied to all
262
+ # `in_channels` channels; the three-channel variant is kept commented out below.
263
+ eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
264
+ # eps, rest = model_out[:, :3], model_out[:, 3:]
265
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
266
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
267
+ eps = torch.cat([half_eps, half_eps], dim=0)
268
+ return torch.cat([eps, rest], dim=1)
269
+
270
+
271
+ #################################################################################
272
+ # Sine/Cosine Positional Embedding Functions #
273
+ #################################################################################
274
+ # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
275
+
276
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
277
+ """
278
+ grid_size: int of the grid height and width
279
+ return:
280
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
281
+ """
282
+ grid_h = np.arange(grid_size, dtype=np.float32)
283
+ grid_w = np.arange(grid_size, dtype=np.float32)
284
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
285
+ grid = np.stack(grid, axis=0)
286
+
287
+ grid = grid.reshape([2, 1, grid_size, grid_size])
288
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
289
+ if cls_token and extra_tokens > 0:
290
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
291
+ return pos_embed
292
+
293
+
294
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
295
+ assert embed_dim % 2 == 0
296
+
297
+ # use half of dimensions to encode grid_h
298
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
299
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
300
+
301
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
302
+ return emb
303
+
304
+
305
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
306
+ """
307
+ embed_dim: output dimension for each position
308
+ pos: a list of positions to be encoded: size (M,)
309
+ out: (M, D)
310
+ """
311
+ assert embed_dim % 2 == 0
312
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
313
+ omega /= embed_dim / 2.
314
+ omega = 1. / 10000**omega # (D/2,)
315
+
316
+ pos = pos.reshape(-1) # (M,)
317
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
318
+
319
+ emb_sin = np.sin(out) # (M, D/2)
320
+ emb_cos = np.cos(out) # (M, D/2)
321
+
322
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
323
+ return emb
324
+
325
+
326
+ #################################################################################
327
+ # DiT Configs #
328
+ #################################################################################
329
+
330
+ def DiT_XL_2(**kwargs):
331
+ return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
332
+
333
+ def DiT_XL_4(**kwargs):
334
+ return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
335
+
336
+ def DiT_XL_8(**kwargs):
337
+ return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
338
+
339
+ def DiT_L_2(**kwargs):
340
+ return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
341
+
342
+ def DiT_L_4(**kwargs):
343
+ return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
344
+
345
+ def DiT_L_8(**kwargs):
346
+ return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
347
+
348
+ def DiT_B_2(**kwargs):
349
+ return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
350
+
351
+ def DiT_B_4(**kwargs):
352
+ return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
353
+
354
+ def DiT_B_8(**kwargs):
355
+ return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
356
+
357
+ def DiT_S_2(**kwargs):
358
+ return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
359
+
360
+ def DiT_S_4(**kwargs):
361
+ return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
362
+
363
+ def DiT_S_8(**kwargs):
364
+ return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
365
+
366
+
367
+ DiT_models = {
368
+ 'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8,
369
+ 'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8,
370
+ 'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8,
371
+ 'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8,
372
+ }
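A quick shape check for the configurations above (a sketch using the constructor defaults shown): with learn_sigma=True the network predicts both epsilon and the variance interpolation values, so the output has twice the input channels.

import torch
from paintmind.stage1.diffusion_transfomers import DiT_models

model = DiT_models['DiT-B/2'](input_size=32, in_channels=4, num_classes=1000)
x = torch.randn(2, 4, 32, 32)                # latent inputs
t = torch.randint(0, 1000, (2,))             # diffusion timesteps
y = torch.randint(0, 1000, (2,))             # class labels
out = model(x, t, y)
assert out.shape == (2, 8, 32, 32)           # 4 eps channels + 4 variance channels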
paintmind/stage1/fused_attention.py ADDED
@@ -0,0 +1,94 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from typing import Type
6
+
7
+ class Attention(nn.Module):
8
+ def __init__(
9
+ self,
10
+ dim: int,
11
+ num_heads: int = 8,
12
+ qkv_bias: bool = False,
13
+ qk_norm: bool = False,
14
+ proj_bias: bool = True,
15
+ attn_drop: float = 0.,
16
+ proj_drop: float = 0.,
17
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
18
+ ) -> None:
19
+ super().__init__()
20
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
21
+ self.num_heads = num_heads
22
+ self.head_dim = dim // num_heads
23
+ self.scale = self.head_dim ** -0.5
24
+
25
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
26
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
27
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
28
+ self.attn_drop = nn.Dropout(attn_drop)
29
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
30
+ self.proj_drop = nn.Dropout(proj_drop)
31
+
32
+ def forward(self, x, attn_mask=None):
33
+ B, N, C = x.shape
34
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
35
+ q, k, v = qkv.unbind(0)
36
+ q, k = self.q_norm(q), self.k_norm(k)
37
+
38
+ x = F.scaled_dot_product_attention(
39
+ q, k, v,
40
+ attn_mask=attn_mask, # in this API, True marks positions that are allowed to attend
41
+ dropout_p=self.attn_drop.p if self.training else 0.,
42
+ )
43
+
44
+ x = x.transpose(1, 2).reshape(B, N, C)
45
+ x = self.proj(x)
46
+ x = self.proj_drop(x)
47
+ return x
48
+
49
+
50
+ class MultiHeadCrossAttention(nn.Module):
51
+ def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
52
+ super().__init__()
53
+ if d_model % num_heads != 0:
54
+ raise AssertionError(
55
+ "d_model (%d) must be divisible by num_heads (%d)"
56
+ % (d_model, num_heads)
57
+ )
58
+
59
+ self.d_model = d_model
60
+ self.num_heads = num_heads
61
+ self.head_dim = d_model // num_heads
62
+
63
+ self.q_linear = nn.Linear(d_model, d_model)
64
+ self.kv_linear = nn.Linear(d_model, d_model * 2)
65
+ self.attn_drop = nn.Dropout(attn_drop)
66
+ self.proj = nn.Linear(d_model, d_model)
67
+ self.proj_drop = nn.Dropout(proj_drop)
68
+
69
+ def forward(self, x, cond, mask=None):
70
+ # query/value: img tokens; key: condition; mask: if padding tokens
71
+ B, N, C = x.shape
72
+
73
+ q = self.q_linear(x).view(-1, self.num_heads, self.head_dim)
74
+ kv = self.kv_linear(cond).view(-1, 2, self.num_heads, self.head_dim)
75
+ k, v = kv.unbind(1)
76
+
77
+ q = q.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
78
+ k = k.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
79
+ v = v.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
80
+ if mask is not None:
81
+ temp_mask = torch.ones(B, 1, q.size(-2), k.size(-2), dtype=torch.bool, device=q.device)
82
+ for i in range(B):
83
+ temp_mask[i, :, :, mask[i]:] = False
84
+ mask = temp_mask
85
+ x = F.scaled_dot_product_attention(
86
+ q, k, v,
87
+ attn_mask=mask,
88
+ dropout_p=self.attn_drop.p if self.training else 0.,
89
+ ).transpose(1, 2)
90
+
91
+ x = x.view(B, -1, C)
92
+ x = self.proj(x)
93
+ x = self.proj_drop(x)
94
+ return x
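A small sketch of the mask convention used by both modules: with F.scaled_dot_product_attention a True entry means the query is allowed to attend to that key, and the mask only needs to be broadcastable to (batch, heads, queries, keys).

import torch
from paintmind.stage1.fused_attention import Attention

attn = Attention(dim=256, num_heads=8, qkv_bias=True)
x = torch.randn(2, 16, 256)

mask = torch.zeros(16, 16, dtype=torch.bool)
mask[:, :8] = True                    # every token may attend only to the first 8 positions
out = attn(x, attn_mask=mask)
assert out.shape == (2, 16, 256)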
paintmind/stage1/pos_embed.py ADDED
@@ -0,0 +1,102 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # Position embedding utils
8
+ # --------------------------------------------------------
9
+
10
+ import numpy as np
11
+
12
+ import torch
13
+
14
+ # --------------------------------------------------------
15
+ # 2D sine-cosine position embedding
16
+ # References:
17
+ # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
18
+ # MoCo v3: https://github.com/facebookresearch/moco-v3
19
+ # --------------------------------------------------------
20
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
21
+ """
22
+ grid_size: int of the grid height and width
23
+ return:
24
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
25
+ """
26
+ grid_h = np.arange(grid_size, dtype=np.float32)
27
+ grid_w = np.arange(grid_size, dtype=np.float32)
28
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
29
+ grid = np.stack(grid, axis=0)
30
+
31
+ grid = grid.reshape([2, 1, grid_size, grid_size])
32
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
33
+ if cls_token:
34
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
35
+ return pos_embed
36
+
37
+
38
+ def get_1d_sincos_pos_embed(embed_dim, grid_size):
39
+ grid = np.arange(grid_size, dtype=np.float32)
40
+ pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid)
41
+ return pos_embed
42
+
43
+
44
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
45
+ assert embed_dim % 2 == 0
46
+
47
+ # use half of dimensions to encode grid_h
48
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
49
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
50
+
51
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
52
+ return emb
53
+
54
+
55
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
56
+ """
57
+ embed_dim: output dimension for each position
58
+ pos: a list of positions to be encoded: size (M,)
59
+ out: (M, D)
60
+ """
61
+ assert embed_dim % 2 == 0
62
+ omega = np.arange(embed_dim // 2, dtype=float)
63
+ omega /= embed_dim / 2.
64
+ omega = 1. / 10000**omega # (D/2,)
65
+
66
+ pos = pos.reshape(-1) # (M,)
67
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
68
+
69
+ emb_sin = np.sin(out) # (M, D/2)
70
+ emb_cos = np.cos(out) # (M, D/2)
71
+
72
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
73
+ return emb
74
+
75
+
76
+ # --------------------------------------------------------
77
+ # Interpolate position embeddings for high-resolution
78
+ # References:
79
+ # DeiT: https://github.com/facebookresearch/deit
80
+ # --------------------------------------------------------
81
+ def interpolate_pos_embed(model, checkpoint_model):
82
+ if 'pos_embed' in checkpoint_model:
83
+ pos_embed_checkpoint = checkpoint_model['pos_embed']
84
+ embedding_size = pos_embed_checkpoint.shape[-1]
85
+ num_patches = model.patch_embed.num_patches
86
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
87
+ # height (== width) for the checkpoint position embedding
88
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
89
+ # height (== width) for the new position embedding
90
+ new_size = int(num_patches ** 0.5)
91
+ # class_token and dist_token are kept unchanged
92
+ if orig_size != new_size:
93
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
94
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
95
+ # only the position tokens are interpolated
96
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
97
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
98
+ pos_tokens = torch.nn.functional.interpolate(
99
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
100
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
101
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
102
+ checkpoint_model['pos_embed'] = new_pos_embed
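For reference, the shapes these helpers produce for a ViT-style grid (a sketch derived from the functions above):

from paintmind.stage1.pos_embed import get_2d_sincos_pos_embed, get_1d_sincos_pos_embed

pe_2d = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14)
assert pe_2d.shape == (14 * 14, 768)                       # (196, 768)

pe_2d_cls = get_2d_sincos_pos_embed(768, 14, cls_token=True)
assert pe_2d_cls.shape == (197, 768)                       # zero row prepended for [CLS]

pe_1d = get_1d_sincos_pos_embed(embed_dim=768, grid_size=16)
assert pe_1d.shape == (16, 768)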
paintmind/stage1/quantize.py ADDED
@@ -0,0 +1,93 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+
6
+ def l2norm(t):
7
+ return F.normalize(t, p = 2, dim = -1)
8
+
9
+ class VectorQuantizer(nn.Module):
10
+ def __init__(self, n_e, e_dim, beta=0.25, use_norm=True):
11
+ super().__init__()
12
+ self.n_e = n_e
13
+ self.e_dim = e_dim
14
+ self.beta = beta
15
+ self.use_norm = use_norm
16
+
17
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
18
+ self.embedding.weight.data.normal_()
19
+
20
+ def forward(self, z):
21
+ if self.use_norm:
22
+ z = l2norm(z)
23
+ z_flattened = z.view(-1, self.e_dim)
24
+ if self.use_norm:
25
+ embedd_norm = l2norm(self.embedding.weight)
26
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
27
+
28
+ if self.use_norm:
29
+ d = 2 - 2 * torch.einsum('bc, nc -> bn', z_flattened, embedd_norm)
30
+ else:
31
+ d = torch.sum(z_flattened**2, dim=1, keepdim=True) + \
32
+ torch.sum(self.embedding.weight**2, dim=1) - 2 * \
33
+ torch.einsum('bd,nd->bn', z_flattened, self.embedding.weight)
34
+
35
+ encoding_indices = torch.argmin(d, dim=1).view(*z.shape[:-1])
36
+ z_q = self.embedding(encoding_indices)
37
+ if self.use_norm:
38
+ z_q = l2norm(z_q)
39
+
40
+ # compute loss for embedding
41
+ loss = self.beta * torch.mean((z_q.detach()-z)**2) + torch.mean((z_q-z.detach())**2)
42
+
43
+ # preserve gradients
44
+ z_q = z + (z_q - z).detach()
45
+
46
+ return z_q, loss, encoding_indices
47
+
48
+ def decode_from_indice(self, indices):
49
+ z_q = self.embedding(indices)
50
+ if self.use_norm:
51
+ z_q = l2norm(z_q)
52
+
53
+ return z_q
54
+
55
+ class DiagonalGaussianDistribution(object):
56
+ def __init__(self, parameters, deterministic=False):
57
+ self.parameters = parameters # [B, L, 2C], not [B, 2C, H, W]
58
+ self.mean, self.logvar = torch.chunk(parameters, 2, dim=2)
59
+ self.logvar = torch.clamp(self.logvar, -30.0, 20.0)
60
+ self.deterministic = deterministic
61
+ self.std = torch.exp(0.5 * self.logvar)
62
+ self.var = torch.exp(self.logvar)
63
+ if self.deterministic:
64
+ self.var = self.std = torch.zeros_like(self.mean).to(device=self.parameters.device)
65
+
66
+ def sample(self):
67
+ x = self.mean + self.std * torch.randn(self.mean.shape).to(device=self.parameters.device)
68
+ return x
69
+
70
+ def kl(self, other=None):
71
+ if self.deterministic:
72
+ return torch.Tensor([0.])
73
+ else:
74
+ if other is None:
75
+ return 0.5 * torch.sum(torch.pow(self.mean, 2)
76
+ + self.var - 1.0 - self.logvar,
77
+ dim=[1, 2])
78
+ else:
79
+ return 0.5 * torch.sum(
80
+ torch.pow(self.mean - other.mean, 2) / other.var
81
+ + self.var / other.var - 1.0 - self.logvar + other.logvar,
82
+ dim=[1, 2])
83
+
84
+ def nll(self, sample, dims=[1,2]):
85
+ if self.deterministic:
86
+ return torch.Tensor([0.])
87
+ logtwopi = np.log(2.0 * np.pi)
88
+ return 0.5 * torch.sum(
89
+ logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var,
90
+ dim=dims)
91
+
92
+ def mode(self):
93
+ return self.mean
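A hedged usage sketch for the quantizer above: the straight-through estimator (z + (z_q - z).detach()) lets gradients reach the encoder while the codebook itself is trained through the commitment/codebook loss.

import torch
from paintmind.stage1.quantize import VectorQuantizer

vq = VectorQuantizer(n_e=1024, e_dim=16)
z = torch.randn(2, 64, 16, requires_grad=True)   # [B, L, C] encoder outputs

z_q, vq_loss, indices = vq(z)
assert z_q.shape == z.shape and indices.shape == (2, 64)
recovered = vq.decode_from_indice(indices)       # look the codes up again from indices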
paintmind/stage1/transport/__init__.py ADDED
@@ -0,0 +1,63 @@
1
+ from .transport import Transport, ModelType, WeightType, PathType, Sampler
2
+
3
+ def create_transport(
4
+ path_type='Linear',
5
+ prediction="velocity",
6
+ loss_weight=None,
7
+ train_eps=None,
8
+ sample_eps=None,
9
+ ):
10
+ """function for creating Transport object
11
+ **Note**: model prediction defaults to velocity
12
+ Args:
13
+ - path_type: type of path to use; default to linear
14
+ - prediction: what the model predicts; one of "velocity" (default),
15
+ "noise", or "score"
16
+ - loss_weight: optional loss weighting; "velocity" weights by the
17
+ velocity weight, "likelihood" by the likelihood weight (None disables weighting)
18
+ - train_eps: small epsilon for avoiding instability during training
19
+ - sample_eps: small epsilon for avoiding instability during sampling
20
+ """
21
+
22
+ if prediction == "noise":
23
+ model_type = ModelType.NOISE
24
+ elif prediction == "score":
25
+ model_type = ModelType.SCORE
26
+ else:
27
+ model_type = ModelType.VELOCITY
28
+
29
+ if loss_weight == "velocity":
30
+ loss_type = WeightType.VELOCITY
31
+ elif loss_weight == "likelihood":
32
+ loss_type = WeightType.LIKELIHOOD
33
+ else:
34
+ loss_type = WeightType.NONE
35
+
36
+ path_choice = {
37
+ "Linear": PathType.LINEAR,
38
+ "GVP": PathType.GVP,
39
+ "VP": PathType.VP,
40
+ }
41
+
42
+ path_type = path_choice[path_type]
43
+
44
+ if (path_type in [PathType.VP]):
45
+ train_eps = 1e-5 if train_eps is None else train_eps
46
+ sample_eps = 1e-3 if sample_eps is None else sample_eps
47
+ elif (path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY):
48
+ train_eps = 1e-3 if train_eps is None else train_eps
49
+ sample_eps = 1e-3 if sample_eps is None else sample_eps
50
+ else: # velocity & [GVP, LINEAR] is stable everywhere
51
+ train_eps = 0
52
+ sample_eps = 0
53
+
54
+ # create flow state
55
+ state = Transport(
56
+ model_type=model_type,
57
+ path_type=path_type,
58
+ loss_type=loss_type,
59
+ train_eps=train_eps,
60
+ sample_eps=sample_eps,
61
+ )
62
+
63
+ return state
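A brief sketch of how the factory above might be called (argument values are illustrative):

from paintmind.stage1.transport import create_transport

# default: linear path with velocity prediction -> train_eps = sample_eps = 0
transport = create_transport()

# a VP path gets small epsilons near the interval endpoints
vp_transport = create_transport(path_type='VP', prediction='score', loss_weight='likelihood')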
paintmind/stage1/transport/integrators.py ADDED
@@ -0,0 +1,130 @@
1
+ import numpy as np
2
+ import torch as th
3
+ import torch.nn as nn
4
+ from torchdiffeq import odeint
5
+ from functools import partial
6
+ from tqdm import tqdm
7
+
8
+ class sde:
9
+ """SDE solver class"""
10
+ def __init__(
11
+ self,
12
+ drift,
13
+ diffusion,
14
+ *,
15
+ t0,
16
+ t1,
17
+ num_steps,
18
+ sampler_type,
19
+ temperature=1.0,
20
+ ):
21
+ assert t0 < t1, "SDE sampler has to be in forward time"
22
+
23
+ self.num_timesteps = num_steps
24
+ self.t = th.linspace(t0, t1, num_steps)
25
+ self.dt = self.t[1] - self.t[0]
26
+ self.drift = drift
27
+ self.diffusion = diffusion
28
+ self.sampler_type = sampler_type
29
+ self.temperature = temperature
30
+
31
+ def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs):
32
+ w_cur = th.randn(x.size()).to(x)
33
+ t = th.ones(x.size(0)).to(x) * t
34
+ dw = w_cur * th.sqrt(self.dt)
35
+ drift = self.drift(x, t, model, **model_kwargs)
36
+ diffusion = self.diffusion(x, t)
37
+ mean_x = x + drift * self.dt
38
+ x = mean_x + th.sqrt(2 * diffusion) * dw * self.temperature
39
+ return x, mean_x
40
+
41
+ def __Heun_step(self, x, _, t, model, **model_kwargs):
42
+ w_cur = th.randn(x.size()).to(x)
43
+ dw = w_cur * th.sqrt(self.dt) * self.temperature
44
+ t_cur = th.ones(x.size(0)).to(x) * t
45
+ diffusion = self.diffusion(x, t_cur)
46
+ xhat = x + th.sqrt(2 * diffusion) * dw
47
+ K1 = self.drift(xhat, t_cur, model, **model_kwargs)
48
+ xp = xhat + self.dt * K1
49
+ K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs)
50
+ return xhat + 0.5 * self.dt * (K1 + K2), xhat # at last time point we do not perform the heun step
51
+
52
+ def __forward_fn(self):
53
+ """TODO: generalize here by adding all private functions ending with steps to it"""
54
+ sampler_dict = {
55
+ "Euler": self.__Euler_Maruyama_step,
56
+ "Heun": self.__Heun_step,
57
+ }
58
+
59
+ try:
60
+ sampler = sampler_dict[self.sampler_type]
61
+ except KeyError:
62
+ raise NotImplementedError("Sampler type not implemented.")
63
+
64
+ return sampler
65
+
66
+ def sample(self, init, model, **model_kwargs):
67
+ """forward loop of sde"""
68
+ x = init
69
+ mean_x = init
70
+ samples = []
71
+ sampler = self.__forward_fn()
72
+ for ti in self.t[:-1]:
73
+ with th.no_grad():
74
+ x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs)
75
+ samples.append(x)
76
+
77
+ return samples
78
+
79
+ class ode:
80
+ """ODE solver class"""
81
+ def __init__(
82
+ self,
83
+ drift,
84
+ *,
85
+ t0,
86
+ t1,
87
+ sampler_type,
88
+ num_steps,
89
+ atol,
90
+ rtol,
91
+ temperature=1.0,
92
+ ):
93
+ assert t0 < t1, "ODE sampler has to be in forward time"
94
+
95
+ self.drift = drift
96
+ self.t = th.linspace(t0, t1, num_steps)
97
+ self.atol = atol
98
+ self.rtol = rtol
99
+ self.sampler_type = sampler_type
100
+ self.temperature = temperature
101
+
102
+ def sample(self, x, model, **model_kwargs):
103
+
104
+ device = x[0].device if isinstance(x, tuple) else x.device
105
+ def _fn(t, x):
106
+ t = th.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else th.ones(x.size(0)).to(device) * t
107
+ # For ODE, we scale the drift by the temperature
108
+ # This is equivalent to scaling time by 1/temperature
109
+ model_output = self.drift(x, t, model, **model_kwargs)
110
+ if self.temperature != 1.0:
111
+ # If it's a tuple (for likelihood calculation), only scale the first element
112
+ if isinstance(model_output, tuple):
113
+ scaled_output = (model_output[0] / self.temperature, model_output[1])
114
+ return scaled_output
115
+ else:
116
+ return model_output / self.temperature
117
+ return model_output
118
+
119
+ t = self.t.to(device)
120
+ atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol]
121
+ rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol]
122
+ samples = odeint(
123
+ _fn,
124
+ x,
125
+ t,
126
+ method=self.sampler_type,
127
+ atol=atol,
128
+ rtol=rtol
129
+ )
130
+ return samples
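The solvers above are normally driven through the Sampler class in transport.py, but they can also be exercised directly. A toy sketch with a hand-written drift and diffusion; both callables are stand-ins, not part of the commit.

import torch as th
from paintmind.stage1.transport.integrators import sde

# signatures match what Sampler.sample_sde constructs: f(x, t, model, **kwargs)
drift = lambda x, t, model, **kw: -x               # pull samples toward zero
diffusion = lambda x, t: 0.1 * th.ones_like(x)     # constant noise scale

solver = sde(drift, diffusion, t0=0.0, t1=1.0, num_steps=50, sampler_type="Euler")
trajectory = solver.sample(th.randn(4, 32), model=None)   # list of num_steps - 1 states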
paintmind/stage1/transport/path.py ADDED
@@ -0,0 +1,192 @@
1
+ import torch as th
2
+ import numpy as np
3
+ from functools import partial
4
+
5
+ def expand_t_like_x(t, x):
6
+ """Function to reshape time t to broadcastable dimension of x
7
+ Args:
8
+ t: [batch_dim,], time vector
9
+ x: [batch_dim,...], data point
10
+ """
11
+ dims = [1] * (len(x.size()) - 1)
12
+ t = t.view(t.size(0), *dims)
13
+ return t
14
+
15
+
16
+ #################### Coupling Plans ####################
17
+
18
+ class ICPlan:
19
+ """Linear Coupling Plan"""
20
+ def __init__(self, sigma=0.0):
21
+ self.sigma = sigma
22
+
23
+ def compute_alpha_t(self, t):
24
+ """Compute the data coefficient along the path"""
25
+ return t, 1
26
+
27
+ def compute_sigma_t(self, t):
28
+ """Compute the noise coefficient along the path"""
29
+ return 1 - t, -1
30
+
31
+ def compute_d_alpha_alpha_ratio_t(self, t):
32
+ """Compute the ratio between d_alpha and alpha"""
33
+ return 1 / t
34
+
35
+ def compute_drift(self, x, t):
36
+ """We always output sde according to score parametrization; """
37
+ t = expand_t_like_x(t, x)
38
+ alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t)
39
+ sigma_t, d_sigma_t = self.compute_sigma_t(t)
40
+ drift = alpha_ratio * x
41
+ diffusion = alpha_ratio * (sigma_t ** 2) - sigma_t * d_sigma_t
42
+
43
+ return -drift, diffusion
44
+
45
+ def compute_diffusion(self, x, t, form="constant", norm=1.0):
46
+ """Compute the diffusion term of the SDE
47
+ Args:
48
+ x: [batch_dim, ...], data point
49
+ t: [batch_dim,], time vector
50
+ form: str, form of the diffusion term
51
+ norm: float, norm of the diffusion term
52
+ """
53
+ t = expand_t_like_x(t, x)
54
+ choices = {
55
+ "constant": norm,
56
+ "SBDM": norm * self.compute_drift(x, t)[1],
57
+ "sigma": norm * self.compute_sigma_t(t)[0],
58
+ "linear": norm * (1 - t),
59
+ "decreasing": 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2,
60
+ "inccreasing-decreasing": norm * th.sin(np.pi * t) ** 2,
61
+ }
62
+
63
+ try:
64
+ diffusion = choices[form]
65
+ except KeyError:
66
+ raise NotImplementedError(f"Diffusion form {form} not implemented")
67
+
68
+ return diffusion
69
+
70
+ def get_score_from_velocity(self, velocity, x, t):
71
+ """Wrapper function: transform velocity prediction model to score
72
+ Args:
73
+ velocity: [batch_dim, ...] shaped tensor; velocity model output
74
+ x: [batch_dim, ...] shaped tensor; x_t data point
75
+ t: [batch_dim,] time tensor
76
+ """
77
+ t = expand_t_like_x(t, x)
78
+ alpha_t, d_alpha_t = self.compute_alpha_t(t)
79
+ sigma_t, d_sigma_t = self.compute_sigma_t(t)
80
+ mean = x
81
+ reverse_alpha_ratio = alpha_t / d_alpha_t
82
+ var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
83
+ score = (reverse_alpha_ratio * velocity - mean) / var
84
+ return score
85
+
86
+ def get_noise_from_velocity(self, velocity, x, t):
87
+ """Wrapper function: transform velocity prediction model to denoiser
88
+ Args:
89
+ velocity: [batch_dim, ...] shaped tensor; velocity model output
90
+ x: [batch_dim, ...] shaped tensor; x_t data point
91
+ t: [batch_dim,] time tensor
92
+ """
93
+ t = expand_t_like_x(t, x)
94
+ alpha_t, d_alpha_t = self.compute_alpha_t(t)
95
+ sigma_t, d_sigma_t = self.compute_sigma_t(t)
96
+ mean = x
97
+ reverse_alpha_ratio = alpha_t / d_alpha_t
98
+ var = reverse_alpha_ratio * d_sigma_t - sigma_t
99
+ noise = (reverse_alpha_ratio * velocity - mean) / var
100
+ return noise
101
+
102
+ def get_velocity_from_score(self, score, x, t):
103
+ """Wrapper function: transform score prediction model to velocity
104
+ Args:
105
+ score: [batch_dim, ...] shaped tensor; score model output
106
+ x: [batch_dim, ...] shaped tensor; x_t data point
107
+ t: [batch_dim,] time tensor
108
+ """
109
+ t = expand_t_like_x(t, x)
110
+ drift, var = self.compute_drift(x, t)
111
+ velocity = var * score - drift
112
+ return velocity
113
+
114
+ def compute_mu_t(self, t, x0, x1):
115
+ """Compute the mean of time-dependent density p_t"""
116
+ t = expand_t_like_x(t, x1)
117
+ alpha_t, _ = self.compute_alpha_t(t)
118
+ sigma_t, _ = self.compute_sigma_t(t)
119
+ return alpha_t * x1 + sigma_t * x0
120
+
121
+ def compute_xt(self, t, x0, x1):
122
+ """Sample xt from time-dependent density p_t; rng is required"""
123
+ xt = self.compute_mu_t(t, x0, x1)
124
+ return xt
125
+
126
+ def compute_ut(self, t, x0, x1, xt):
127
+ """Compute the vector field corresponding to p_t"""
128
+ t = expand_t_like_x(t, x1)
129
+ _, d_alpha_t = self.compute_alpha_t(t)
130
+ _, d_sigma_t = self.compute_sigma_t(t)
131
+ return d_alpha_t * x1 + d_sigma_t * x0
132
+
133
+ def plan(self, t, x0, x1):
134
+ xt = self.compute_xt(t, x0, x1)
135
+ ut = self.compute_ut(t, x0, x1, xt)
136
+ return t, xt, ut
137
+
138
+
139
+ class VPCPlan(ICPlan):
140
+ """class for VP path flow matching"""
141
+
142
+ def __init__(self, sigma_min=0.1, sigma_max=20.0):
143
+ self.sigma_min = sigma_min
144
+ self.sigma_max = sigma_max
145
+ self.log_mean_coeff = lambda t: -0.25 * ((1 - t) ** 2) * (self.sigma_max - self.sigma_min) - 0.5 * (1 - t) * self.sigma_min
146
+ self.d_log_mean_coeff = lambda t: 0.5 * (1 - t) * (self.sigma_max - self.sigma_min) + 0.5 * self.sigma_min
147
+
148
+
149
+ def compute_alpha_t(self, t):
150
+ """Compute coefficient of x1"""
151
+ alpha_t = self.log_mean_coeff(t)
152
+ alpha_t = th.exp(alpha_t)
153
+ d_alpha_t = alpha_t * self.d_log_mean_coeff(t)
154
+ return alpha_t, d_alpha_t
155
+
156
+ def compute_sigma_t(self, t):
157
+ """Compute coefficient of x0"""
158
+ p_sigma_t = 2 * self.log_mean_coeff(t)
159
+ sigma_t = th.sqrt(1 - th.exp(p_sigma_t))
160
+ d_sigma_t = th.exp(p_sigma_t) * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t)
161
+ return sigma_t, d_sigma_t
162
+
163
+ def compute_d_alpha_alpha_ratio_t(self, t):
164
+ """Special-purpose function for computing a numerically stable d_alpha_t / alpha_t"""
165
+ return self.d_log_mean_coeff(t)
166
+
167
+ def compute_drift(self, x, t):
168
+ """Compute the drift term of the SDE"""
169
+ t = expand_t_like_x(t, x)
170
+ beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min)
171
+ return -0.5 * beta_t * x, beta_t / 2
172
+
173
+
174
+ class GVPCPlan(ICPlan):
175
+ def __init__(self, sigma=0.0):
176
+ super().__init__(sigma)
177
+
178
+ def compute_alpha_t(self, t):
179
+ """Compute coefficient of x1"""
180
+ alpha_t = th.sin(t * np.pi / 2)
181
+ d_alpha_t = np.pi / 2 * th.cos(t * np.pi / 2)
182
+ return alpha_t, d_alpha_t
183
+
184
+ def compute_sigma_t(self, t):
185
+ """Compute coefficient of x0"""
186
+ sigma_t = th.cos(t * np.pi / 2)
187
+ d_sigma_t = -np.pi / 2 * th.sin(t * np.pi / 2)
188
+ return sigma_t, d_sigma_t
189
+
190
+ def compute_d_alpha_alpha_ratio_t(self, t):
191
+ """Special-purpose function for computing a numerically stable d_alpha_t / alpha_t"""
192
+ return np.pi / (2 * th.tan(t * np.pi / 2))
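A small sketch of what the coupling plans compute; for the linear plan, x_t = t*x1 + (1-t)*x0 and the target velocity is x1 - x0 (shapes below are illustrative):

import torch as th
from paintmind.stage1.transport.path import ICPlan, GVPCPlan

x1 = th.randn(4, 32)        # data
x0 = th.randn(4, 32)        # noise
t = th.rand(4)              # per-sample times in [0, 1]

t, xt, ut = ICPlan().plan(t, x0, x1)          # linear interpolant and its velocity
_, xt_g, ut_g = GVPCPlan().plan(t, x0, x1)    # sine/cosine (GVP) interpolant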
paintmind/stage1/transport/transport.py ADDED
@@ -0,0 +1,456 @@
1
+ import torch as th
2
+ import numpy as np
3
+ import logging
4
+
5
+ import enum
6
+
7
+ from . import path
8
+ from .utils import EasyDict, log_state, mean_flat
9
+ from .integrators import ode, sde
10
+
11
+ class ModelType(enum.Enum):
12
+ """
13
+ Which type of output the model predicts.
14
+ """
15
+
16
+ NOISE = enum.auto() # the model predicts epsilon
17
+ SCORE = enum.auto() # the model predicts \nabla \log p(x)
18
+ VELOCITY = enum.auto() # the model predicts v(x)
19
+
20
+ class PathType(enum.Enum):
21
+ """
22
+ Which type of path to use.
23
+ """
24
+
25
+ LINEAR = enum.auto()
26
+ GVP = enum.auto()
27
+ VP = enum.auto()
28
+
29
+ class WeightType(enum.Enum):
30
+ """
31
+ Which type of weighting to use.
32
+ """
33
+
34
+ NONE = enum.auto()
35
+ VELOCITY = enum.auto()
36
+ LIKELIHOOD = enum.auto()
37
+
38
+
39
+ class Transport:
40
+
41
+ def __init__(
42
+ self,
43
+ *,
44
+ model_type,
45
+ path_type,
46
+ loss_type,
47
+ train_eps,
48
+ sample_eps,
49
+ ):
50
+ path_options = {
51
+ PathType.LINEAR: path.ICPlan,
52
+ PathType.GVP: path.GVPCPlan,
53
+ PathType.VP: path.VPCPlan,
54
+ }
55
+
56
+ self.loss_type = loss_type
57
+ self.model_type = model_type
58
+ self.path_sampler = path_options[path_type]()
59
+ self.train_eps = train_eps
60
+ self.sample_eps = sample_eps
61
+
62
+ def prior_logp(self, z):
63
+ '''
64
+ Standard multivariate normal prior
65
+ Assume z is batched
66
+ '''
67
+ shape = th.tensor(z.size())
68
+ N = th.prod(shape[1:])
69
+ _fn = lambda x: -N / 2. * np.log(2 * np.pi) - th.sum(x ** 2) / 2.
70
+ return th.vmap(_fn)(z)
71
+
72
+
73
+ def check_interval(
74
+ self,
75
+ train_eps,
76
+ sample_eps,
77
+ *,
78
+ diffusion_form="SBDM",
79
+ sde=False,
80
+ reverse=False,
81
+ eval=False,
82
+ last_step_size=0.0,
83
+ ):
84
+ t0 = 0
85
+ t1 = 1
86
+ eps = train_eps if not eval else sample_eps
87
+ if (type(self.path_sampler) in [path.VPCPlan]):
88
+
89
+ t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size
90
+
91
+ elif (type(self.path_sampler) in [path.ICPlan, path.GVPCPlan]) \
92
+ and (self.model_type != ModelType.VELOCITY or sde): # avoid numerical issue by taking a first semi-implicit step
93
+
94
+ t0 = eps if (diffusion_form == "SBDM" and sde) or self.model_type != ModelType.VELOCITY else 0
95
+ t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size
96
+
97
+ if reverse:
98
+ t0, t1 = 1 - t0, 1 - t1
99
+
100
+ return t0, t1
101
+
102
+
103
+ def sample(self, x1):
104
+ """Sampling x0 & t based on shape of x1 (if needed)
105
+ Args:
106
+ x1 - data point; [batch, *dim]
107
+ """
108
+
109
+ x0 = th.randn_like(x1)
110
+ t0, t1 = self.check_interval(self.train_eps, self.sample_eps)
111
+ t = th.rand((x1.shape[0],)) * (t1 - t0) + t0
112
+ t = t.to(x1)
113
+ return t, x0, x1
114
+
115
+
116
+ def training_losses(
117
+ self,
118
+ model,
119
+ x1,
120
+ model_kwargs=None
121
+ ):
122
+ """Loss for training the score model
123
+ Args:
124
+ - model: backbone model; could be score, noise, or velocity
125
+ - x1: datapoint
126
+ - model_kwargs: additional arguments for the model
127
+ """
128
+ if model_kwargs == None:
129
+ model_kwargs = {}
130
+
131
+ t, x0, x1 = self.sample(x1)
132
+ t, xt, ut = self.path_sampler.plan(t, x0, x1)
133
+ model_output = model(xt, t, **model_kwargs)
134
+ if len(model_output.shape) == len(xt.shape) + 1:
135
+ x0 = x0.unsqueeze(-1).expand(*([-1] * (len(x0.shape))), model_output.shape[-1])
136
+ xt = xt.unsqueeze(-1).expand(*([-1] * (len(xt.shape))), model_output.shape[-1])
137
+ ut = ut.unsqueeze(-1).expand(*([-1] * (len(ut.shape))), model_output.shape[-1])
138
+ B, C = xt.shape[:2]
139
+ assert model_output.shape == (B, C, *xt.shape[2:])
140
+
141
+ terms = {}
142
+ terms['pred'] = model_output
143
+ if self.model_type == ModelType.VELOCITY:
144
+ terms['loss'] = mean_flat(((model_output - ut) ** 2))
145
+ else:
146
+ _, drift_var = self.path_sampler.compute_drift(xt, t)
147
+ sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, xt))
148
+ if self.loss_type in [WeightType.VELOCITY]:
149
+ weight = (drift_var / sigma_t) ** 2
150
+ elif self.loss_type in [WeightType.LIKELIHOOD]:
151
+ weight = drift_var / (sigma_t ** 2)
152
+ elif self.loss_type in [WeightType.NONE]:
153
+ weight = 1
154
+ else:
155
+ raise NotImplementedError()
156
+
157
+ if self.model_type == ModelType.NOISE:
158
+ terms['loss'] = mean_flat(weight * ((model_output - x0) ** 2))
159
+ else:
160
+ terms['loss'] = mean_flat(weight * ((model_output * sigma_t + x0) ** 2))
161
+
162
+ return terms
163
+
164
+
165
+ def get_drift(
166
+ self
167
+ ):
168
+ """member function for obtaining the drift of the probability flow ODE"""
169
+ def score_ode(x, t, model, **model_kwargs):
170
+ drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
171
+ model_output = model(x, t, **model_kwargs)
172
+ return (-drift_mean + drift_var * model_output) # by change of variable
173
+
174
+ def noise_ode(x, t, model, **model_kwargs):
175
+ drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
176
+ sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))
177
+ model_output = model(x, t, **model_kwargs)
178
+ score = model_output / -sigma_t
179
+ return (-drift_mean + drift_var * score)
180
+
181
+ def velocity_ode(x, t, model, **model_kwargs):
182
+ model_output = model(x, t, **model_kwargs)
183
+ return model_output
184
+
185
+ if self.model_type == ModelType.NOISE:
186
+ drift_fn = noise_ode
187
+ elif self.model_type == ModelType.SCORE:
188
+ drift_fn = score_ode
189
+ else:
190
+ drift_fn = velocity_ode
191
+
192
+ def body_fn(x, t, model, **model_kwargs):
193
+ model_output = drift_fn(x, t, model, **model_kwargs)
194
+ assert model_output.shape == x.shape, "Output shape from ODE solver must match input shape"
195
+ return model_output
196
+
197
+ return body_fn
198
+
199
+
200
+ def get_score(
201
+ self,
202
+ ):
203
+ """member function for obtaining score of
204
+ x_t = alpha_t * x + sigma_t * eps"""
205
+ if self.model_type == ModelType.NOISE:
206
+ score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs) / -self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0]
207
+ elif self.model_type == ModelType.SCORE:
208
+ score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs)
209
+ elif self.model_type == ModelType.VELOCITY:
210
+ score_fn = lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity(model(x, t, **kwargs), x, t)
211
+ else:
212
+ raise NotImplementedError()
213
+
214
+ return score_fn
215
+
216
+
217
+ class Sampler:
218
+ """Sampler class for the transport model"""
219
+ def __init__(
220
+ self,
221
+ transport,
222
+ ):
223
+ """Constructor for a general sampler; supporting different sampling methods
224
+ Args:
225
+ - transport: a transport object specifying model prediction & interpolant type
226
+ """
227
+
228
+ self.transport = transport
229
+ self.drift = self.transport.get_drift()
230
+ self.score = self.transport.get_score()
231
+
232
+ def __get_sde_diffusion_and_drift(
233
+ self,
234
+ *,
235
+ diffusion_form="SBDM",
236
+ diffusion_norm=1.0,
237
+ ):
238
+
239
+ def diffusion_fn(x, t):
240
+ diffusion = self.transport.path_sampler.compute_diffusion(x, t, form=diffusion_form, norm=diffusion_norm)
241
+ return diffusion
242
+
243
+ sde_drift = \
244
+ lambda x, t, model, **kwargs: \
245
+ self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(x, t, model, **kwargs)
246
+
247
+ sde_diffusion = diffusion_fn
248
+
249
+ return sde_drift, sde_diffusion
250
+
251
+ def __get_last_step(
252
+ self,
253
+ sde_drift,
254
+ *,
255
+ last_step,
256
+ last_step_size,
257
+ ):
258
+ """Get the last step function of the SDE solver"""
259
+
260
+ if last_step is None:
261
+ last_step_fn = \
262
+ lambda x, t, model, **model_kwargs: \
263
+ x
264
+ elif last_step == "Mean":
265
+ last_step_fn = \
266
+ lambda x, t, model, **model_kwargs: \
267
+ x + sde_drift(x, t, model, **model_kwargs) * last_step_size
268
+ elif last_step == "Tweedie":
269
+ alpha = self.transport.path_sampler.compute_alpha_t # simple aliasing; the original name was too long
270
+ sigma = self.transport.path_sampler.compute_sigma_t
271
+ last_step_fn = \
272
+ lambda x, t, model, **model_kwargs: \
273
+ x / alpha(t)[0][0] + (sigma(t)[0][0] ** 2) / alpha(t)[0][0] * self.score(x, t, model, **model_kwargs)
274
+ elif last_step == "Euler":
275
+ last_step_fn = \
276
+ lambda x, t, model, **model_kwargs: \
277
+ x + self.drift(x, t, model, **model_kwargs) * last_step_size
278
+ else:
279
+ raise NotImplementedError()
280
+
281
+ return last_step_fn
282
+
283
+ def sample_sde(
284
+ self,
285
+ *,
286
+ sampling_method="Euler",
287
+ diffusion_form="SBDM",
288
+ diffusion_norm=1.0,
289
+ last_step="Mean",
290
+ last_step_size=0.04,
291
+ num_steps=250,
292
+ temperature=1.0,
293
+ ):
294
+ """returns a sampling function with given SDE settings
295
+ Args:
296
+ - sampling_method: type of sampler used in solving the SDE; default to be Euler-Maruyama
297
+ - diffusion_form: function form of diffusion coefficient; default to be matching SBDM
298
+ - diffusion_norm: function magnitude of diffusion coefficient; default to 1
299
+ - last_step: type of the last step; default to identity
300
+ - last_step_size: size of the last step; default to match the stride of 250 steps over [0,1]
301
+ - num_steps: total integration step of SDE
302
+ - temperature: temperature scaling for the noise during sampling; default to 1.0
303
+ """
304
+
305
+ if last_step is None:
306
+ last_step_size = 0.0
307
+
308
+ sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift(
309
+ diffusion_form=diffusion_form,
310
+ diffusion_norm=diffusion_norm,
311
+ )
312
+
313
+ t0, t1 = self.transport.check_interval(
314
+ self.transport.train_eps,
315
+ self.transport.sample_eps,
316
+ diffusion_form=diffusion_form,
317
+ sde=True,
318
+ eval=True,
319
+ reverse=False,
320
+ last_step_size=last_step_size,
321
+ )
322
+
323
+ _sde = sde(
324
+ sde_drift,
325
+ sde_diffusion,
326
+ t0=t0,
327
+ t1=t1,
328
+ num_steps=num_steps,
329
+ sampler_type=sampling_method,
330
+ temperature=temperature
331
+ )
332
+
333
+ last_step_fn = self.__get_last_step(sde_drift, last_step=last_step, last_step_size=last_step_size)
334
+
335
+
336
+ def _sample(init, model, **model_kwargs):
337
+ xs = _sde.sample(init, model, **model_kwargs)
338
+ ts = th.ones(init.size(0), device=init.device) * t1
339
+ x = last_step_fn(xs[-1], ts, model, **model_kwargs)
340
+ xs.append(x)
341
+
342
+ assert len(xs) == num_steps, "Number of samples does not match the number of steps"
343
+
344
+ return xs
345
+
346
+ return _sample
347
+
348
+ def sample_ode(
349
+ self,
350
+ *,
351
+ sampling_method="dopri5",
352
+ num_steps=50,
353
+ atol=1e-6,
354
+ rtol=1e-3,
355
+ reverse=False,
356
+ temperature=1.0,
357
+ ):
358
+ """returns a sampling function with given ODE settings
359
+ Args:
360
+ - sampling_method: type of sampler used in solving the ODE; default to be Dopri5
361
+ - num_steps:
362
+ - fixed solver (Euler, Heun): the actual number of integration steps performed
363
+ - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
364
+ - atol: absolute error tolerance for the solver
365
+ - rtol: relative error tolerance for the solver
366
+ - reverse: whether solving the ODE in reverse (data to noise); default to False
367
+ - temperature: temperature scaling for the drift during sampling; default to 1.0
368
+ """
369
+ if reverse:
370
+ drift = lambda x, t, model, **kwargs: self.drift(x, th.ones_like(t) * (1 - t), model, **kwargs)
371
+ else:
372
+ drift = self.drift
373
+
374
+ t0, t1 = self.transport.check_interval(
375
+ self.transport.train_eps,
376
+ self.transport.sample_eps,
377
+ sde=False,
378
+ eval=True,
379
+ reverse=reverse,
380
+ last_step_size=0.0,
381
+ )
382
+
383
+ _ode = ode(
384
+ drift=drift,
385
+ t0=t0,
386
+ t1=t1,
387
+ sampler_type=sampling_method,
388
+ num_steps=num_steps,
389
+ atol=atol,
390
+ rtol=rtol,
391
+ temperature=temperature,
392
+ )
393
+
394
+ return _ode.sample
395
+
396
+ def sample_ode_likelihood(
397
+ self,
398
+ *,
399
+ sampling_method="dopri5",
400
+ num_steps=50,
401
+ atol=1e-6,
402
+ rtol=1e-3,
403
+ temperature=1.0,
404
+ ):
405
+
406
+ """returns a sampling function for calculating likelihood with given ODE settings
407
+ Args:
408
+ - sampling_method: type of sampler used in solving the ODE; default to be Dopri5
409
+ - num_steps:
410
+ - fixed solver (Euler, Heun): the actual number of integration steps performed
411
+ - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
412
+ - atol: absolute error tolerance for the solver
413
+ - rtol: relative error tolerance for the solver
414
+ - temperature: temperature scaling for the drift during sampling; default to 1.0
415
+ """
416
+ def _likelihood_drift(x, t, model, **model_kwargs):
417
+ x, _ = x
418
+ eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1
419
+ t = th.ones_like(t) * (1 - t)
420
+ with th.enable_grad():
421
+ x.requires_grad = True
422
+ grad = th.autograd.grad(th.sum(self.drift(x, t, model, **model_kwargs) * eps), x)[0]
423
+ logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size()))))
424
+ drift = self.drift(x, t, model, **model_kwargs)
425
+ return (-drift, logp_grad)
426
+
427
+ t0, t1 = self.transport.check_interval(
428
+ self.transport.train_eps,
429
+ self.transport.sample_eps,
430
+ sde=False,
431
+ eval=True,
432
+ reverse=False,
433
+ last_step_size=0.0,
434
+ )
435
+
436
+ _ode = ode(
437
+ drift=_likelihood_drift,
438
+ t0=t0,
439
+ t1=t1,
440
+ sampler_type=sampling_method,
441
+ num_steps=num_steps,
442
+ atol=atol,
443
+ rtol=rtol,
444
+ temperature=temperature,
445
+ )
446
+
447
+ def _sample_fn(x, model, **model_kwargs):
448
+ init_logp = th.zeros(x.size(0)).to(x)
449
+ input = (x, init_logp)
450
+ drift, delta_logp = _ode.sample(input, model, **model_kwargs)
451
+ drift, delta_logp = drift[-1], delta_logp[-1]
452
+ prior_logp = self.transport.prior_logp(drift)
453
+ logp = prior_logp - delta_logp
454
+ return logp, drift
455
+
456
+ return _sample_fn
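An end-to-end sketch tying Transport and Sampler together. The zero-velocity lambda is a placeholder for a real denoiser, and the ODE route relies on torchdiffeq being installed.

import torch as th
from paintmind.stage1.transport import create_transport, Sampler

transport = create_transport()                    # linear path, velocity prediction
model = lambda x, t, **kw: th.zeros_like(x)       # placeholder velocity model

# training: per-sample flow-matching loss
terms = transport.training_losses(model, th.randn(8, 32))
loss = terms['loss'].mean()

# sampling: integrate the probability-flow ODE from noise
sampler = Sampler(transport)
sample_fn = sampler.sample_ode(num_steps=8)
samples = sample_fn(th.randn(8, 32), model)[-1]   # final state of the trajectory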
paintmind/stage1/transport/utils.py ADDED
@@ -0,0 +1,29 @@
1
+ import torch as th
2
+
3
+ class EasyDict:
4
+
5
+ def __init__(self, sub_dict):
6
+ for k, v in sub_dict.items():
7
+ setattr(self, k, v)
8
+
9
+ def __getitem__(self, key):
10
+ return getattr(self, key)
11
+
12
+ def mean_flat(x):
13
+ """
14
+ Take the mean over all non-batch dimensions.
15
+ """
16
+ return th.mean(x, dim=list(range(1, len(x.size()))))
17
+
18
+ def log_state(state):
19
+ result = []
20
+
21
+ sorted_state = dict(sorted(state.items()))
22
+ for key, value in sorted_state.items():
23
+ # Check if the value is an instance of a class
24
+ if "<object" in str(value) or "object at" in str(value):
25
+ result.append(f"{key}: [{value.__class__.__name__}]")
26
+ else:
27
+ result.append(f"{key}: {value}")
28
+
29
+ return '\n'.join(result)
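These helpers are tiny; a short usage sketch:

import torch as th
from paintmind.stage1.transport.utils import EasyDict, mean_flat

cfg = EasyDict({"num_steps": 250, "cfg_scale": 1.0})
assert cfg.num_steps == cfg["num_steps"]            # attribute and item access both work
per_sample = mean_flat(th.randn(4, 16, 32) ** 2)    # mean over non-batch dims -> shape [4]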
paintmind/stage1/vision_transformers.py ADDED
@@ -0,0 +1,267 @@
1
+ # Copyright (c) ByteDance, Inc. and its affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Mostly copy-paste from DINO and timm library:
9
+ https://github.com/facebookresearch/dino
10
+ https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
11
+ """
12
+
13
+ import math
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
+ from functools import partial
19
+ from paintmind.stage1.fused_attention import Attention
20
+
21
+ __all__ = ['VisionTransformer', 'vit_tiny_patch16', 'vit_small_patch16',
22
+ 'vit_base_patch16', 'vit_large_patch16', 'vit_huge_patch14']
23
+
24
+
25
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
26
+ if drop_prob == 0. or not training:
27
+ return x
28
+ keep_prob = 1 - drop_prob
29
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
30
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
31
+ if keep_prob > 0.0:
32
+ random_tensor.div_(keep_prob)
33
+ return x * random_tensor
34
+
35
+
36
+ class DropPath(nn.Module):
37
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
38
+ """
39
+
40
+ def __init__(self, drop_prob=None):
41
+ super(DropPath, self).__init__()
42
+ self.drop_prob = drop_prob
43
+
44
+ def forward(self, x):
45
+ return drop_path(x, self.drop_prob, self.training)
46
+
47
+
48
+ class Mlp(nn.Module):
49
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
50
+ super().__init__()
51
+ out_features = out_features or in_features
52
+ hidden_features = hidden_features or in_features
53
+ self.fc1 = nn.Linear(in_features, hidden_features)
54
+ self.act = act_layer()
55
+ self.fc2 = nn.Linear(hidden_features, out_features)
56
+ self.drop = nn.Dropout(drop)
57
+
58
+ def forward(self, x):
59
+ x = self.fc1(x)
60
+ x = self.act(x)
61
+ x = self.drop(x)
62
+ x = self.fc2(x)
63
+ x = self.drop(x)
64
+ return x
65
+
66
+
67
+ class Block(nn.Module):
68
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0.,
69
+ attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, init_values=0):
70
+ super().__init__()
71
+ self.norm1 = norm_layer(dim)
72
+ self.attn = Attention(
73
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
74
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
75
+ self.norm2 = norm_layer(dim)
76
+ mlp_hidden_dim = int(dim * mlp_ratio)
77
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
78
+
79
+ if init_values > 0:
80
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
81
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
82
+ else:
83
+ self.gamma_1, self.gamma_2 = None, None
84
+
85
+ def forward(self, x, attn_mask=None):
86
+ y = self.attn(self.norm1(x), attn_mask=attn_mask)
87
+ if self.gamma_1 is None:
88
+ x = x + self.drop_path(y)
89
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
90
+ else:
91
+ x = x + self.drop_path(self.gamma_1 * y)
92
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
93
+ return x
94
+
95
+
96
+ class PatchEmbed(nn.Module):
97
+ """ Image to Patch Embedding
98
+ """
99
+
100
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
101
+ super().__init__()
102
+ num_patches = (img_size // patch_size) * (img_size // patch_size)
103
+ self.img_size = img_size
104
+ self.patch_size = patch_size
105
+ self.num_patches = num_patches
106
+
107
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
108
+
109
+ def forward(self, x):
110
+ B, C, H, W = x.shape
111
+ return self.proj(x)
112
+
113
+
114
+ class VisionTransformer(nn.Module):
115
+ """ Vision Transformer """
116
+
117
+ def __init__(self, img_size=[224], patch_size=16, in_chans=3, embed_dim=768, depth=12,
118
+ num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0.,
119
+ drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),
120
+ init_values=0, num_slots=16, slot_through=True):
121
+ super().__init__()
122
+ self.num_features = self.embed_dim = embed_dim
123
+
124
+ self.patch_embed = PatchEmbed(
125
+ img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
126
+ num_patches = self.patch_embed.num_patches
127
+
128
+ self.num_slots = num_slots if slot_through else 0
129
+ self.slot_through = slot_through
130
+
131
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
132
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1 + self.num_slots, embed_dim))
133
+ if self.slot_through:
134
+ self.slot_embed = nn.Parameter(torch.zeros(1, num_slots, embed_dim))
135
+
136
+ self.pos_drop = nn.Dropout(p=drop_rate)
137
+
138
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
139
+ self.blocks = nn.ModuleList([
140
+ Block(
141
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
142
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
143
+ init_values=init_values)
144
+ for i in range(depth)])
145
+
146
+ self.norm = norm_layer(embed_dim)
147
+
148
+ nn.init.trunc_normal_(self.pos_embed, std=.02)
149
+ nn.init.trunc_normal_(self.cls_token, std=.02)
150
+ if self.slot_through:
151
+ nn.init.trunc_normal_(self.slot_embed, std=.02)
152
+ self.apply(self._init_weights)
153
+
154
+ def _init_weights(self, m):
155
+ if isinstance(m, nn.Linear):
156
+ nn.init.trunc_normal_(m.weight, std=.02)
157
+ if isinstance(m, nn.Linear) and m.bias is not None:
158
+ nn.init.constant_(m.bias, 0)
159
+ elif isinstance(m, nn.LayerNorm):
160
+ nn.init.constant_(m.bias, 0)
161
+ nn.init.constant_(m.weight, 1.0)
162
+
163
+ def interpolate_pos_encoding(self, x, w, h):
164
+ npatch = x.shape[1] - 1 - self.num_slots
165
+ N = self.pos_embed.shape[1] - 1 - self.num_slots
166
+ if npatch == N and w == h:
167
+ return self.pos_embed
168
+ class_pos_embed = self.pos_embed[:, 0]
169
+ patch_pos_embed = self.pos_embed[:, 1:1+npatch]
170
+ dim = x.shape[-1]
171
+ w0 = w // self.patch_embed.patch_size
172
+ h0 = h // self.patch_embed.patch_size
173
+ # we add a small number to avoid floating point error in the interpolation
174
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
175
+ w0, h0 = w0 + 0.1, h0 + 0.1
176
+ patch_pos_embed = nn.functional.interpolate(
177
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
178
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
179
+ mode='bicubic',
180
+ )
181
+ assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
182
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
183
+
184
+ if self.slot_through:
185
+ slots_pos_embed = self.pos_embed[:, 1+npatch:]
186
+ slots_pos_embed = slots_pos_embed.view(1, -1, dim) # (1, num_slots, dim)
187
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed, slots_pos_embed), dim=1)
188
+
189
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
190
+
191
+ def prepare_tokens(self, x):
192
+ B, nc, w, h = x.shape
193
+ x = self.patch_embed(x)
194
+ x = x.flatten(2).transpose(1, 2)
195
+ if self.slot_through:
196
+ x = torch.cat((self.cls_token.expand(B, -1, -1), x, self.slot_embed.expand(B, -1, -1)), dim=1)
197
+ else:
198
+ x = torch.cat((self.cls_token.expand(B, -1, -1), x), dim=1)
199
+ x = x + self.interpolate_pos_encoding(x, w, h)
200
+ return self.pos_drop(x)
201
+
202
+ def forward(self, x, is_causal=True):
203
+ x = self.prepare_tokens(x)
204
+ if is_causal and self.slot_through:
205
+ attn_mask = torch.ones(x.shape[1], x.shape[1], device=x.device, dtype=torch.bool)
206
+ # slots are causal to each other
207
+ causal_mask = torch.ones(self.num_slots, self.num_slots, device=x.device, dtype=torch.bool).tril(diagonal=0)
208
+ attn_mask[-self.num_slots:, -self.num_slots:] = causal_mask
209
+ # cls token and patches should not see slots
210
+ attn_mask[:-self.num_slots, -self.num_slots:] = False
211
+ else:
212
+ attn_mask = None
213
+
214
+ for blk in self.blocks:
215
+ x = blk(x, attn_mask=attn_mask)
216
+
217
+ x = self.norm(x)
218
+ if self.slot_through:
219
+ outcome = x[:, -self.num_slots:] # return the slots
220
+ else:
221
+ outcome = x[:, 1:] # return the patches
222
+ return outcome
223
+
224
+ def get_intermediate_layers(self, x, n=1):
225
+ x = self.prepare_tokens(x)
226
+ # we return the output tokens from the `n` last blocks
227
+ output = []
228
+ for i, blk in enumerate(self.blocks):
229
+ x = blk(x)
230
+ if len(self.blocks) - i <= n:
231
+ output.append(self.norm(x))
232
+ return output
233
+
234
+
235
+ def vit_tiny_patch16(**kwargs):
236
+ model = VisionTransformer(
237
+ patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
238
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
239
+ return model
240
+
241
+
242
+ def vit_small_patch16(**kwargs):
243
+ model = VisionTransformer(
244
+ patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
245
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
246
+ return model
247
+
248
+
249
+ def vit_base_patch16(**kwargs):
250
+ model = VisionTransformer(
251
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
252
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
253
+ return model
254
+
255
+
256
+ def vit_large_patch16(**kwargs):
257
+ model = VisionTransformer(
258
+ patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
259
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
260
+ return model
261
+
262
+
263
+ def vit_huge_patch14(**kwargs):
264
+ model = VisionTransformer(
265
+ patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,
266
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
267
+ return model
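A quick sketch of the slot-producing encoder defined above (image size and slot count are illustrative); with slot_through=True the forward pass returns the causally ordered slot tokens.

import torch
from paintmind.stage1.vision_transformers import vit_base_patch16

encoder = vit_base_patch16(img_size=[256], num_slots=16, slot_through=True)
images = torch.randn(2, 3, 256, 256)
slots = encoder(images, is_causal=True)   # [2, 16, 768]; cls/patch tokens cannot attend to slots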
paintmind/stage2/__init__.py ADDED
File without changes
paintmind/stage2/causaldit.py ADDED
@@ -0,0 +1,422 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+
8
+ from typing import Optional
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+
14
+ from timm.models.vision_transformer import Mlp
15
+ from paintmind.stage1.diffusion_transfomers import TimestepEmbedder, LabelEmbedder, FinalLayer, modulate
16
+ from paintmind.stage1.diffusion import create_diffusion
17
+ from paintmind.stage1.transport import create_transport, Sampler
18
+
19
+
20
+ class GeneralizedCausalAttention(nn.Module):
21
+ def __init__(self, dim, num_heads, norm_layer=nn.LayerNorm):
22
+ super().__init__()
23
+ assert dim % num_heads == 0
24
+ self.num_heads = num_heads
25
+ self.head_dim = dim // num_heads
26
+ self.scale = self.head_dim ** -0.5
27
+ self.qkv = nn.Linear(dim, 3 * dim, bias=False)
28
+ self.proj = nn.Linear(dim, dim)
29
+ self.q_norm = norm_layer(self.head_dim)
30
+ self.k_norm = norm_layer(self.head_dim)
31
+
32
+ def _forward_kv_cache(
33
+ self,
34
+ x: torch.Tensor,
35
+ layer_index: int,
36
+ kv_cache: dict,
37
+ update_kv_cache: bool = False,
38
+ ):
39
+ N, Lq = x.shape[:2]
40
+ qkv = self.qkv(x).reshape(N, Lq, 3, self.num_heads, self.head_dim)
41
+ q, curr_k, curr_v = qkv.permute(2, 0, 3, 1, 4).unbind(0) # N, nhead, Lq, dhead
42
+ q = self.q_norm(q)
43
+ curr_k = self.k_norm(curr_k)
44
+
45
+ if kv_cache[layer_index]["k"] is not None:
46
+ k = kv_cache[layer_index]["k"]
47
+ v = kv_cache[layer_index]["v"]
48
+ k = torch.cat((k, curr_k), dim=2)
49
+ v = torch.cat((v, curr_v), dim=2)
50
+ else:
51
+ k = curr_k
52
+ v = curr_v
53
+
54
+ if update_kv_cache:
55
+ kv_cache[layer_index]["k"] = k
56
+ kv_cache[layer_index]["v"] = v
57
+
58
+ return self._forward_sdpa(q, k, v, attn_mask=None)
59
+
60
+ def _forward(self, x, attn_mask):
61
+ N, L = x.shape[:2]
62
+ qkv = self.qkv(x).reshape(N, L, 3, self.num_heads, self.head_dim)
63
+ q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0) # N, nhead, L, dhead
64
+ q = self.q_norm(q)
65
+ k = self.k_norm(k)
66
+ return self._forward_sdpa(q, k, v, attn_mask)
67
+
68
+ def _forward_sdpa(self, q, k, v, attn_mask):
69
+ N, _, Lq, _ = q.shape
70
+ x = F.scaled_dot_product_attention(
71
+ q, k, v,
72
+ attn_mask=attn_mask,
73
+ )
74
+
75
+ x = x.transpose(1, 2).reshape(N, Lq, -1)
76
+ x = self.proj(x)
77
+ return x
78
+
79
+ def forward(
80
+ self,
81
+ x: torch.Tensor,
82
+ attn_mask: Optional[torch.Tensor] = None,
83
+ kv_cache: Optional[dict] = None,
84
+ layer_index: Optional[int] = None,
85
+ update_kv_cache: Optional[bool] = None,
86
+ ) -> torch.Tensor:
87
+ if kv_cache is not None:
88
+ return self._forward_kv_cache(
89
+ x,
90
+ kv_cache=kv_cache,
91
+ layer_index=layer_index,
92
+ update_kv_cache=update_kv_cache
93
+ )
94
+ else:
95
+ return self._forward(x, attn_mask)
96
+
97
+
98
+ class DiTBlock(nn.Module):
99
+ def __init__(
100
+ self,
101
+ hidden_size,
102
+ num_heads,
103
+ mlp_ratio=4.0,
104
+ norm_layer=nn.LayerNorm,
105
+ causal_fusion=False,
106
+ deep_supervision=False,
107
+ output_dim=None,
108
+ ):
109
+ super().__init__()
110
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
111
+ self.attn = GeneralizedCausalAttention(hidden_size, num_heads=num_heads, norm_layer=norm_layer)
112
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
113
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
114
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
115
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
116
+ self.causal_fusion = causal_fusion
117
+ if not causal_fusion:
118
+ self.adaLN_modulation = nn.Sequential(
119
+ nn.SiLU(),
120
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
121
+ )
122
+ self.deep_supervision = deep_supervision
123
+ if deep_supervision:
124
+ if not causal_fusion:
125
+ self.final_layer = FinalLayer(hidden_size, 1, output_dim)
126
+ else:
127
+ self.final_layer = nn.Sequential(
128
+ nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6),
129
+ nn.Linear(hidden_size, output_dim)
130
+ )
131
+
132
+ def forward_causal_dit(self, x, c, **kwargs):
133
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
134
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), **kwargs)
135
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
136
+ if self.deep_supervision and self.training:
137
+ return x, self.final_layer(x, c)
138
+ else:
139
+ return x
140
+
141
+ def forward_causal_fusion(self, x, **kwargs):
142
+ x = x + self.attn(self.norm1(x), **kwargs)
143
+ x = x + self.mlp(self.norm2(x))
144
+ if self.deep_supervision and self.training:
145
+ return x, self.final_layer(x)
146
+ else:
147
+ return x
148
+
149
+ def forward(self, x, c=None, **kwargs):
150
+ if self.causal_fusion:
151
+ return self.forward_causal_fusion(x, **kwargs)
152
+ else:
153
+ return self.forward_causal_dit(x, c, **kwargs)
154
+
155
+ class CausalDiT(nn.Module):
156
+ def __init__(
157
+ self,
158
+ num_slots=16,
159
+ slot_dim=256,
160
+ hidden_size=1152,
161
+ depth=28,
162
+ num_heads=16,
163
+ mlp_ratio=4.0,
164
+ class_dropout_prob=0.1,
165
+ num_classes=1000,
166
+ num_sampling_steps='250',
167
+ use_si=False,
168
+ predict_xstart=False,
169
+ causal_fusion=False,
170
+ deep_supervision=False,
171
+ cls_token_num=0,
172
+ **kwargs
173
+ ):
174
+ super().__init__()
175
+ self.num_slots = num_slots
176
+ self.slot_dim = slot_dim
177
+ self.num_heads = num_heads
178
+ self.hidden_size = hidden_size
179
+ self.num_classes = num_classes
180
+ self.output_dim = slot_dim * 2 if not use_si else slot_dim
181
+ self.x_embedder = nn.Linear(slot_dim, hidden_size)
182
+ self.t_embedder = TimestepEmbedder(hidden_size)
183
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
184
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_slots, hidden_size))
185
+ blocks = [
186
+ DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, causal_fusion=causal_fusion,
187
+ deep_supervision=deep_supervision, output_dim=self.output_dim) for _ in range(depth - 1)
188
+ ]
189
+ blocks.append(DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, causal_fusion=causal_fusion))
190
+ self.blocks = nn.ModuleList(blocks)
191
+ self.cls_token_num = cls_token_num
192
+ self.causal_fusion = causal_fusion
193
+ self.deep_supervision = deep_supervision
194
+ if not causal_fusion:
195
+ self.final_layer = FinalLayer(hidden_size, 1, self.output_dim)
196
+ else:
197
+ self.cond_pos_embed = nn.Parameter(torch.zeros(1, cls_token_num, hidden_size))
198
+ self.final_layer = nn.Sequential(
199
+ nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6),
200
+ nn.Linear(hidden_size, self.output_dim)
201
+ )
202
+
203
+ self.initialize_weights()
204
+
205
+ self.use_si = use_si
206
+ if not use_si:
207
+ self.train_diffusion = create_diffusion(timestep_respacing="", predict_xstart=predict_xstart)
208
+ self.gen_diffusion = create_diffusion(timestep_respacing=num_sampling_steps, predict_xstart=predict_xstart)
209
+ else:
210
+ self.transport = create_transport()
211
+ self.sampler = Sampler(self.transport)
212
+
213
+
214
+ def initialize_weights(self):
215
+ # Initialize transformer layers:
216
+ def _basic_init(module):
217
+ if isinstance(module, nn.Linear):
218
+ torch.nn.init.xavier_uniform_(module.weight)
219
+ if module.bias is not None:
220
+ nn.init.constant_(module.bias, 0)
221
+ if self.causal_fusion and isinstance(module, nn.LayerNorm):
222
+ if module.weight is not None:
223
+ nn.init.constant_(module.weight, 1.0)
224
+ if module.bias is not None:
225
+ nn.init.constant_(module.bias, 0)
226
+ self.apply(_basic_init)
227
+
228
+ # Initialize pos_embed:
229
+ nn.init.normal_(self.pos_embed, std=0.02)
230
+ # Initialize label embedding table:
231
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
232
+ if self.causal_fusion:
233
+ nn.init.normal_(self.cond_pos_embed, std=0.02)
234
+
235
+ # Initialize timestep embedding MLP:
236
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
237
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
238
+
239
+ # Zero-out adaLN modulation layers in DiT blocks:
240
+ for i, block in enumerate(self.blocks):
241
+ if not self.causal_fusion:
242
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
243
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
244
+ if self.deep_supervision and i < len(self.blocks) - 1:
245
+ if not self.causal_fusion:
246
+ nn.init.constant_(block.final_layer.adaLN_modulation[-1].weight, 0)
247
+ nn.init.constant_(block.final_layer.adaLN_modulation[-1].bias, 0)
248
+ nn.init.constant_(block.final_layer.linear.weight, 0)
249
+ nn.init.constant_(block.final_layer.linear.bias, 0)
250
+ else:
251
+ nn.init.constant_(block.final_layer[1].weight, 0)
252
+ nn.init.constant_(block.final_layer[1].bias, 0)
253
+
254
+ # Zero-out output layers:
255
+ if not self.causal_fusion:
256
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
257
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
258
+ nn.init.constant_(self.final_layer.linear.weight, 0)
259
+ nn.init.constant_(self.final_layer.linear.bias, 0)
260
+ else:
261
+ nn.init.constant_(self.final_layer[1].weight, 0)
262
+ nn.init.constant_(self.final_layer[1].bias, 0)
263
+
264
+ def forward_cache_update(self, x, kv_cache):
265
+ for idx, block in enumerate(self.blocks):
266
+ x = block(x, layer_index=idx, kv_cache=kv_cache, update_kv_cache=True)
267
+ return None
268
+
269
+ def _forward_inference(self, xn, t, y, kv_cache, pos_embed, attn_mask, context):
270
+ xn = xn.transpose(1, 2)
271
+ if context is not None:
272
+ xn = torch.cat([context, xn], dim=1)
273
+ xn = self.x_embedder(xn) + pos_embed
274
+ y = self.y_embedder(y, self.training)
275
+ t = self.t_embedder(t)
276
+ if not self.causal_fusion:
277
+ c = t + y
278
+ else:
279
+ y = y.unsqueeze(1).expand(-1, self.cls_token_num, -1) + self.cond_pos_embed
280
+ xn = torch.cat([y, xn], dim=1)
281
+ t = t.unsqueeze(1)
282
+ xn = xn + t
283
+ c = None
284
+
285
+ for idx, block in enumerate(self.blocks):
286
+ xn = block(xn, c, attn_mask=attn_mask, layer_index=idx, kv_cache=kv_cache, update_kv_cache=False)
287
+
288
+ xn = xn[:, -1].unsqueeze(1)
289
+ xn = self._forward_final_layer(xn, c)
290
+ return xn.transpose(1, 2)
291
+
292
+ def _forward_inference_with_cfg(self, xn, t, y, kv_cache, pos_embed, attn_mask, context, cfg_scale):
293
+ half = xn[: len(xn) // 2]
294
+ combined = torch.cat([half, half], dim=0)
295
+ model_out = self._forward_inference(combined, t, y, kv_cache, pos_embed, attn_mask, context)
296
+ eps, rest = model_out[:, :self.slot_dim], model_out[:, self.slot_dim:]
297
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
298
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
299
+ eps = torch.cat([half_eps, half_eps], dim=0)
300
+ return torch.cat([eps, rest], dim=1)
301
+
302
+ def sample(self, y, pos_embed, attn_mask, context=None, cfg=1.0):
303
+ # diffusion loss sampling
304
+ device = y.device
305
+ if not cfg == 1.0:
306
+ noise = torch.randn(y.shape[0] // 2, self.slot_dim, 1, device=device)
307
+ noise = torch.cat([noise, noise], dim=0)
308
+ model_kwargs = dict(y=y, kv_cache=None, pos_embed=pos_embed, attn_mask=attn_mask, context=context, cfg_scale=cfg)
309
+ sample_fn = self._forward_inference_with_cfg
310
+ else:
311
+ noise = torch.randn(y.shape[0], self.slot_dim, 1, device=device)
312
+ model_kwargs = dict(y=y, kv_cache=None, pos_embed=pos_embed, attn_mask=attn_mask, context=context)
313
+ sample_fn = self._forward_inference
314
+
315
+ if not self.use_si:
316
+ sampled_token_latent = self.gen_diffusion.p_sample_loop(
317
+ sample_fn, noise.shape, noise, clip_denoised=False, model_kwargs=model_kwargs, progress=False,
318
+ device=device
319
+ )
320
+ else:
321
+ sde_sample_fn = self.sampler.sample_sde(diffusion_form="sigma")
322
+ sampled_token_latent = sde_sample_fn(noise, sample_fn, **model_kwargs)[-1]
323
+
324
+ return sampled_token_latent.transpose(1, 2)
325
+
326
+ def _forward_final_layer(self, x, c):
327
+ if not self.causal_fusion:
328
+ return self.final_layer(x, c)
329
+ else:
330
+ return self.final_layer(x)
331
+
332
+ def forward_train(self, xn, t, y, xc):
333
+ """
334
+ Args:
335
+ xn: noised latent
336
+ t: time step
337
+ y: condition
338
+ xc: clean latent
339
+ """
340
+
341
+ xc = self.x_embedder(xc.transpose(1, 2)) + self.pos_embed
342
+ xn = self.x_embedder(xn.transpose(1, 2)) + self.pos_embed
343
+
344
+ t = self.t_embedder(t)
345
+ y = self.y_embedder(y, self.training)
346
+ if not self.causal_fusion:
347
+ c = t + y
348
+ else:
349
+ y = y.unsqueeze(1).expand(-1, self.cls_token_num, -1) + self.cond_pos_embed
350
+ t = t.unsqueeze(1)
351
+ xc = torch.cat((y, xc), dim=1)
352
+ xn = xn + t
353
+ c = None
354
+
355
+ # forward transformer
356
+ x = torch.cat((xc, xn), dim=1)
357
+ attn_mask = get_attn_mask(self.cls_token_num, self.num_slots).to(x.device)
358
+
359
+ if self.deep_supervision and self.training:
360
+ xs = []
361
+ for block in self.blocks[:-1]:
362
+ x, x_hat = block(x, c=c, attn_mask=attn_mask)
363
+ xs.append(x_hat[:, -self.num_slots:])
364
+ x = self.blocks[-1](x, c=c, attn_mask=attn_mask)
365
+ x = self._forward_final_layer(x[:, -self.num_slots:], c)
366
+ xs.append(x)
367
+ # N, B, L, C -> B, C, L, N
368
+ return torch.stack(xs, dim=0).permute(1, 3, 2, 0)
369
+ else:
370
+ for block in self.blocks:
371
+ x = block(x, c=c, attn_mask=attn_mask)
372
+ x = x[:, -self.num_slots:]
373
+ return self._forward_final_layer(x, c).transpose(1, 2)
374
+
375
+ def forward(self, slots, targets):
376
+ slots = slots.transpose(1, 2)
377
+ model_kwargs = dict(y=targets, xc=slots)
378
+ if not self.use_si:
379
+ t = torch.randint(0, self.train_diffusion.num_timesteps, (slots.shape[0],), device=slots.device)
380
+ loss_dict = self.train_diffusion.training_losses(self.forward_train, slots, t, model_kwargs)
381
+ else:
382
+ loss_dict = self.transport.training_losses(self.forward_train, slots, model_kwargs)
383
+ loss = loss_dict["loss"]
384
+ return loss.mean()
385
+
386
+
387
+ def CausalDiT_L(**kwargs):
388
+ return CausalDiT(depth=24, hidden_size=1024, num_heads=16, **kwargs)
389
+
390
+
391
+ def CausalDiT_XL(**kwargs):
392
+ return CausalDiT(depth=32, hidden_size=1280, num_heads=20, **kwargs)
393
+
394
+
395
+ def CausalDiT_H(**kwargs):
396
+ return CausalDiT(depth=48, hidden_size=1408, num_heads=22, **kwargs)
397
+
398
+
399
+ CausalDiT_models = {
400
+ "CausalDiT-L": CausalDiT_L,
401
+ "CausalDiT-XL": CausalDiT_XL,
402
+ "CausalDiT-H": CausalDiT_H
403
+ }
404
+
405
+ def get_attn_mask(context_len, sample_len):
406
+ padder_len = 1 if context_len <= 0 else context_len
407
+ seq_len = padder_len + sample_len * 2
408
+ attn_mask = torch.eye(seq_len, dtype=bool)
409
+ triangle = torch.ones(sample_len, sample_len, dtype=bool).tril()
410
+ attn_mask[-sample_len:, :sample_len] = triangle
411
+ if padder_len == 1:
412
+ return attn_mask[1:, 1:]
413
+ else:
414
+ return attn_mask
415
+
416
+ # if context_len == 0
417
+ # 100000
418
+ # 010000
419
+ # 001000
420
+ # 000100
421
+ # 100010
422
+ # 110001
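A hedged training-time sketch of the causal DiT head. Batch size, slot shape, and class count are assumptions; the slots would normally come from the stage-1 encoder.

import torch
from paintmind.stage2.causaldit import CausalDiT_models

model = CausalDiT_models["CausalDiT-L"](num_slots=16, slot_dim=256, num_classes=1000)
slots = torch.randn(2, 16, 256)           # [B, num_slots, slot_dim]
labels = torch.randint(0, 1000, (2,))
loss = model(slots, labels)               # diffusion loss, averaged over the batch
loss.backward()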
paintmind/stage2/diffloss.py ADDED
@@ -0,0 +1,314 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.checkpoint import checkpoint
4
+ import math
5
+
6
+ from paintmind.stage1.diffusion import create_diffusion
7
+ from paintmind.stage1.transport import create_transport, Sampler
8
+
9
+
10
+ class DiffLoss(nn.Module):
11
+ """Diffusion Loss"""
12
+ def __init__(self, target_channels, z_channels, depth, width, num_sampling_steps, predict_xstart=False, use_si=False, deep_supervision=False, token_drop_prob=0.0, cond_method="adaln", decoupled_cfg=True):
13
+ super(DiffLoss, self).__init__()
14
+ self.in_channels = target_channels
15
+ self.net = SimpleMLPAdaLN(
16
+ in_channels=target_channels,
17
+ model_channels=width,
18
+ out_channels=target_channels * 2 if not use_si else target_channels, # for vlb loss
19
+ z_channels=z_channels,
20
+ num_res_blocks=depth,
21
+ deep_supervision=deep_supervision,
22
+ token_drop_prob=token_drop_prob,
23
+ cond_method=cond_method,
24
+ decoupled_cfg=decoupled_cfg,
25
+ )
26
+ self.use_si = use_si
27
+ if not use_si:
28
+ self.train_diffusion = create_diffusion(timestep_respacing="", noise_schedule="cosine", predict_xstart=predict_xstart)
29
+ self.gen_diffusion = create_diffusion(timestep_respacing=num_sampling_steps, noise_schedule="cosine", predict_xstart=predict_xstart)
30
+ else:
31
+ self.transport = create_transport()
32
+ self.sampler = Sampler(self.transport)
33
+
34
+ def forward(self, target, z, mask=None):
35
+ model_kwargs = dict(c=z)
36
+ if not self.use_si:
37
+ t = torch.randint(0, self.train_diffusion.num_timesteps, (target.shape[0],), device=target.device)
38
+ loss_dict = self.train_diffusion.training_losses(self.net, target, t, model_kwargs)
39
+ else:
40
+ loss_dict = self.transport.training_losses(self.net, target, model_kwargs)
41
+ loss = loss_dict["loss"]
42
+ if mask is not None:
43
+ loss = (loss * mask).sum() / mask.sum()
44
+ return loss.mean()
45
+
46
+ def sample(self, z, temperature=1.0, cfg=1.0):
47
+ # diffusion loss sampling
48
+ device = z.device
49
+ if not cfg == 1.0:
50
+ noise = torch.randn(z.shape[0] // 2, self.in_channels, device=device)
51
+ noise = torch.cat([noise, noise], dim=0)
52
+ model_kwargs = dict(c=z, cfg_scale=cfg)
53
+ sample_fn = self.net.forward_with_cfg
54
+ else:
55
+ noise = torch.randn(z.shape[0], self.in_channels, device=device)
56
+ model_kwargs = dict(c=z)
57
+ sample_fn = self.net.forward
58
+
59
+ if not self.use_si:
60
+ sampled_token_latent = self.gen_diffusion.p_sample_loop(
61
+ sample_fn, noise.shape, noise, clip_denoised=False, model_kwargs=model_kwargs, progress=False,
62
+ temperature=temperature, device=device
63
+ )
64
+ else:
65
+ sde_sample_fn = self.sampler.sample_sde(diffusion_form="sigma", temperature=temperature)
66
+ sampled_token_latent = sde_sample_fn(noise, sample_fn, **model_kwargs)[-1]
67
+ if cfg != 1.0:
68
+ sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0)
69
+ return sampled_token_latent
70
+
71
+
72
+ def modulate(x, shift, scale):
73
+ return x * (1 + scale) + shift
74
+
75
+
76
+ class TimestepEmbedder(nn.Module):
77
+ """
78
+ Embeds scalar timesteps into vector representations.
79
+ """
80
+ def __init__(self, hidden_size, frequency_embedding_size=256):
81
+ super().__init__()
82
+ self.mlp = nn.Sequential(
83
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
84
+ nn.SiLU(),
85
+ nn.Linear(hidden_size, hidden_size, bias=True),
86
+ )
87
+ self.frequency_embedding_size = frequency_embedding_size
88
+
89
+ @staticmethod
90
+ def timestep_embedding(t, dim, max_period=10000):
91
+ """
92
+ Create sinusoidal timestep embeddings.
93
+ :param t: a 1-D Tensor of N indices, one per batch element.
94
+ These may be fractional.
95
+ :param dim: the dimension of the output.
96
+ :param max_period: controls the minimum frequency of the embeddings.
97
+ :return: an (N, D) Tensor of positional embeddings.
98
+ """
99
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
100
+ half = dim // 2
101
+ freqs = torch.exp(
102
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
103
+ ).to(device=t.device)
104
+ args = t[:, None].float() * freqs[None]
105
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
106
+ if dim % 2:
107
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
108
+ return embedding
109
+
110
+ def forward(self, t):
111
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
112
+ t_emb = self.mlp(t_freq)
113
+ return t_emb
114
+
115
+
116
+ class ResBlock(nn.Module):
117
+ """
118
+ A residual block with AdaLN for timestep and optional concatenation for condition.
119
+ """
120
+ def __init__(
121
+ self,
122
+ channels,
123
+ out_channels=None,
124
+ deep_supervision=False,
125
+ cond_method="adaln",
126
+ ):
127
+ super().__init__()
128
+ self.channels = channels
129
+ self.deep_supervision = deep_supervision
130
+ self.cond_method = cond_method
131
+
132
+ self.in_ln = nn.LayerNorm(channels, eps=1e-6)
133
+ self.adaLN_modulation = nn.Sequential(
134
+ nn.SiLU(),
135
+ nn.Linear(channels, 3 * channels, bias=True)
136
+ )
137
+
138
+ # Input dimension depends on conditioning method
139
+ mlp_in_dim = channels * 2 if cond_method == "concat" else channels
140
+ self.mlp = nn.Sequential(
141
+ nn.Linear(mlp_in_dim, channels, bias=True),
142
+ nn.SiLU(),
143
+ nn.Linear(channels, channels, bias=True),
144
+ )
145
+
146
+ if self.deep_supervision:
147
+ self.final_layer = FinalLayer(channels, out_channels, cond_method=cond_method)
148
+
149
+ def forward(self, x, t, c=None):
150
+ # Apply timestep embedding via AdaLN
151
+ shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(t).chunk(3, dim=-1)
152
+ h = modulate(self.in_ln(x), shift_mlp, scale_mlp)
153
+
154
+ # Concatenate condition if using concat method
155
+ if self.cond_method == "concat" and c is not None:
156
+ h = torch.cat([h, c], dim=-1)
157
+
158
+ h = self.mlp(h)
159
+ x = x + gate_mlp * h
160
+
161
+ if self.deep_supervision and self.training:
162
+ return x, self.final_layer(x, t, c)
163
+ return x
164
+
165
+
166
+ class FinalLayer(nn.Module):
167
+ """
168
+ Final layer with AdaLN for timestep and optional concatenation for condition.
169
+ """
170
+ def __init__(self, model_channels, out_channels, cond_method="adaln"):
171
+ super().__init__()
172
+ self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
173
+ self.cond_method = cond_method
174
+
175
+ self.adaLN_modulation = nn.Sequential(
176
+ nn.SiLU(),
177
+ nn.Linear(model_channels, 2 * model_channels, bias=True)
178
+ )
179
+
180
+ # Output dimension depends on conditioning method
181
+ linear_in_dim = model_channels * 2 if cond_method == "concat" else model_channels
182
+ self.linear = nn.Linear(linear_in_dim, out_channels, bias=True)
183
+
184
+ def forward(self, x, t, c=None):
185
+ # Apply timestep embedding via AdaLN
186
+ shift, scale = self.adaLN_modulation(t).chunk(2, dim=-1)
187
+ x = modulate(self.norm_final(x), shift, scale)
188
+
189
+ # Concatenate condition if using concat method
190
+ if self.cond_method == "concat" and c is not None:
191
+ x = torch.cat([x, c], dim=-1)
192
+
193
+ return self.linear(x)
194
+
195
+
196
+ class SimpleMLPAdaLN(nn.Module):
197
+ """
198
+ MLP for Diffusion Loss with AdaLN for timestep and optional concatenation for condition.
199
+ """
200
+ def __init__(
201
+ self,
202
+ in_channels,
203
+ model_channels,
204
+ out_channels,
205
+ z_channels,
206
+ num_res_blocks,
207
+ deep_supervision=False,
208
+ token_drop_prob=0.0,
209
+ cond_method="adaln",
210
+ decoupled_cfg=True,
211
+ ):
212
+ super().__init__()
213
+ self.in_channels = in_channels
214
+ self.model_channels = model_channels
215
+ self.out_channels = out_channels
216
+ self.deep_supervision = deep_supervision
217
+ self.token_drop_prob = token_drop_prob
218
+ self.cond_method = cond_method
219
+ self.decoupled_cfg = decoupled_cfg
220
+ if decoupled_cfg and token_drop_prob > 0.0:
221
+ self.null_token = nn.Parameter(torch.zeros(1, z_channels))
222
+
223
+ self.time_embed = TimestepEmbedder(model_channels)
224
+ self.cond_embed = nn.Linear(z_channels, model_channels)
225
+ self.input_proj = nn.Linear(in_channels, model_channels)
226
+
227
+ # Create residual blocks
228
+ res_blocks = []
229
+ for i in range(num_res_blocks - 1):
230
+ res_blocks.append(ResBlock(model_channels, out_channels, deep_supervision, cond_method))
231
+ res_blocks.append(ResBlock(model_channels, cond_method=cond_method))
232
+ self.res_blocks = nn.ModuleList(res_blocks)
233
+
234
+ self.final_layer = FinalLayer(model_channels, out_channels, cond_method=cond_method)
235
+ self.initialize_weights()
236
+
237
+ def initialize_weights(self):
238
+ # Basic initialization for all linear layers
239
+ def _basic_init(module):
240
+ if isinstance(module, nn.Linear):
241
+ torch.nn.init.xavier_uniform_(module.weight)
242
+ if module.bias is not None:
243
+ nn.init.constant_(module.bias, 0)
244
+ self.apply(_basic_init)
245
+
246
+ # Initialize timestep embedding MLP
247
+ nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
248
+ nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
249
+
250
+ if self.token_drop_prob > 0.0:
251
+ nn.init.normal_(self.null_token, std=0.02)
252
+
253
+ # Zero-out adaLN modulation layers (always used for timestep)
254
+ for i, block in enumerate(self.res_blocks):
255
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
256
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
257
+ if self.deep_supervision and i < len(self.res_blocks) - 1:
258
+ nn.init.constant_(block.final_layer.adaLN_modulation[-1].weight, 0)
259
+ nn.init.constant_(block.final_layer.adaLN_modulation[-1].bias, 0)
260
+ nn.init.constant_(block.final_layer.linear.weight, 0)
261
+ nn.init.constant_(block.final_layer.linear.bias, 0)
262
+
263
+ # Zero-out output layers
264
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
265
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
266
+ nn.init.constant_(self.final_layer.linear.weight, 0)
267
+ nn.init.constant_(self.final_layer.linear.bias, 0)
268
+
269
+ def forward(self, x, t, c):
270
+ """
271
+ Apply the model to an input batch.
272
+ :param x: an [N x C] Tensor of inputs.
273
+ :param t: a 1-D batch of timesteps.
274
+ :param c: conditioning from AR transformer.
275
+ :return: an [N x C] Tensor of outputs.
276
+ """
277
+ x = self.input_proj(x)
278
+ t_emb = self.time_embed(t)
279
+
280
+ # Apply token dropout if needed
281
+ if self.token_drop_prob > 0.0 and self.training:
282
+ drop_ids = torch.rand(c.shape[0], 1, device=c.device) < self.token_drop_prob
283
+ c = torch.where(drop_ids, self.null_token, c)
284
+ c_emb = self.cond_embed(c)
285
+
286
+ # Prepare conditioning based on method
287
+ if self.cond_method == "adaln":
288
+ t_combined, c_for_concat = t_emb + c_emb, None
289
+ else: # concat
290
+ t_combined, c_for_concat = t_emb, c_emb
291
+
292
+ if self.deep_supervision and self.training:
293
+ xs = []
294
+ for block in self.res_blocks[:-1]:
295
+ x, x_hat = block(x, t_combined, c_for_concat)
296
+ xs.append(x_hat)
297
+ x = self.res_blocks[-1](x, t_combined, c_for_concat)
298
+ x = self.final_layer(x, t_combined, c_for_concat)
299
+ xs.append(x)
300
+ return torch.stack(xs, dim=-1)
301
+ else:
302
+ for block in self.res_blocks:
303
+ x = block(x, t_combined, c_for_concat)
304
+ return self.final_layer(x, t_combined, c_for_concat)
305
+
306
+ def forward_with_cfg(self, x, t, c, cfg_scale):
307
+ half = x[: len(x) // 2]
308
+ combined = torch.cat([half, half], dim=0)
309
+ model_out = self.forward(combined, t, c)
310
+ eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
311
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
312
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
313
+ eps = torch.cat([half_eps, half_eps], dim=0)
314
+ return torch.cat([eps, rest], dim=1)
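A rough usage sketch for the DiffLoss head defined above (the sizes are illustrative and the snippet assumes the paintmind package, with its diffusion and transport submodules, is importable; it exercises the default DDPM branch, i.e. use_si=False):

import torch
from paintmind.stage2.diffloss import DiffLoss

diffloss = DiffLoss(target_channels=16, z_channels=768, depth=3, width=1024,
                    num_sampling_steps='100')
target = torch.randn(8, 16)   # per-slot latents to be denoised
z = torch.randn(8, 768)       # conditioning vectors from the AR transformer
loss = diffloss(target, z)    # scalar diffusion training loss
latents = diffloss.sample(torch.randn(8, 768), temperature=1.0, cfg=1.0)
print(loss.item(), latents.shape)   # e.g. 1.23 torch.Size([8, 16])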
paintmind/stage2/generate.py ADDED
@@ -0,0 +1,127 @@
1
+ # Modified from:
2
+ # gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
3
+ # DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
4
+ import torch
5
+
6
+ def prefill(model, cond_idx: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, diff_cfg: float = 1.0, temperature: float = 1.0):
7
+ tokens = model(None, cond_idx, input_pos, cfg=cfg_scale, diff_cfg=diff_cfg, temperature=temperature)
8
+ return tokens.unsqueeze(1)
9
+
10
+
11
+ def decode_one_token(model, x: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, diff_cfg: float = 1.0, temperature: float = 1.0):
12
+ assert input_pos.shape[-1] == 1
13
+ if cfg_scale > 1.0:
14
+ x = torch.cat([x, x])
15
+ tokens = model(x, cond_idx=None, input_pos=input_pos, cfg=cfg_scale, diff_cfg=diff_cfg, temperature=temperature)
16
+ return tokens
17
+
18
+
19
+ def decode_n_tokens(
20
+ model, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int,
21
+ cfg_scale: float, diff_cfg: float = 1.0, temperature: float = 1.0, cfg_schedule = "constant", diff_cfg_schedule = "constant"):
22
+ new_tokens = []
23
+ for i in range(num_new_tokens):
24
+ cfg_iter = get_cfg(cfg_scale, i + 1, num_new_tokens + 1, cfg_schedule)
25
+ diff_cfg_iter = get_cfg(diff_cfg, i + 1, num_new_tokens + 1, diff_cfg_schedule)
26
+ next_token = decode_one_token(model, cur_token, input_pos, cfg_iter, diff_cfg=diff_cfg_iter, temperature=temperature).unsqueeze(1)
27
+ input_pos += 1
28
+ new_tokens.append(next_token.clone())
29
+ cur_token = next_token
30
+
31
+ return new_tokens
32
+
33
+
34
+ @torch.no_grad()
35
+ def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, diff_cfg: float = 1.0, temperature: float = 1.0, cfg_schedule = "constant", diff_cfg_schedule = "constant"):
36
+ if cfg_scale > 1.0:
37
+ cond_null = torch.ones_like(cond) * model.num_classes
38
+ cond_combined = torch.cat([cond, cond_null])
39
+ else:
40
+ cond_combined = cond
41
+ T = model.cls_token_num
42
+
43
+ T_new = T + max_new_tokens
44
+ max_seq_length = T_new
45
+ max_batch_size = cond.shape[0]
46
+
47
+ device = cond.device
48
+ dtype = model.z_proj.weight.dtype
49
+ if torch.is_autocast_enabled():
50
+ dtype = torch.get_autocast_dtype(device_type=device.type)
51
+ with torch.device(device):
52
+ max_batch_size_cfg = max_batch_size * 2 if cfg_scale > 1.0 else max_batch_size
53
+ model.setup_caches(max_batch_size=max_batch_size_cfg, max_seq_length=max_seq_length, dtype=dtype)
54
+
55
+ if emb_masks is not None:
56
+ assert emb_masks.shape[0] == max_batch_size
57
+ assert emb_masks.shape[-1] == T
58
+ if cfg_scale > 1.0:
59
+ model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * torch.cat([emb_masks, emb_masks]).unsqueeze(1)
60
+ else:
61
+ model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * emb_masks.unsqueeze(1)
62
+
63
+ eye_matrix = torch.eye(model.causal_mask.size(1), model.causal_mask.size(2), device=device)
64
+ model.causal_mask[:] = model.causal_mask * (1 - eye_matrix) + eye_matrix
65
+
66
+ # create an empty tensor of the expected final shape and fill in the current tokens
67
+ seq = torch.empty((max_batch_size, T_new, model.slot_dim), dtype=dtype, device=device)
68
+
69
+ input_pos = torch.arange(0, T, device=device)
70
+ cfg_iter = get_cfg(cfg_scale, 0, max_new_tokens, cfg_schedule)
71
+ diff_cfg_iter = get_cfg(diff_cfg, 0, max_new_tokens, diff_cfg_schedule)
72
+ next_token = prefill(model, cond_combined, input_pos, cfg_iter, diff_cfg=diff_cfg_iter, temperature=temperature)
73
+ seq[:, T:T+1] = next_token
74
+
75
+ if max_new_tokens > 1:
76
+ input_pos = torch.tensor([T], device=device, dtype=torch.int)
77
+ generated_tokens = decode_n_tokens(model, next_token, input_pos, max_new_tokens - 1, cfg_scale, diff_cfg=diff_cfg, temperature=temperature, cfg_schedule=cfg_schedule, diff_cfg_schedule=diff_cfg_schedule)
78
+ seq[:, T+1:] = torch.cat(generated_tokens, dim=1)
79
+
80
+ model.reset_caches()
81
+ return seq[:, T:]
82
+
83
+
84
+ def get_cfg(cfg, cur_step, total_step, cfg_schedule="constant"):
85
+ if cfg_schedule == "linear":
86
+ return 1 + (cfg - 1) * (cur_step + 1) / total_step
87
+ elif cfg_schedule == "inv_linear":
88
+ return 1 + (cfg - 1) * (total_step - cur_step - 1) / total_step
89
+ elif cfg_schedule == "constant":
90
+ return cfg
91
+ else:
92
+ raise NotImplementedError
93
+
94
+
95
+ @torch.no_grad()
96
+ def generate_causal_dit(model, cond, max_new_tokens, cfg_scale=1.0):
97
+ assert max_new_tokens == model.num_slots
98
+
99
+ batch_size = cond.shape[0]
100
+ device = cond.device
101
+
102
+ if cfg_scale > 1.0:
103
+ cond_null = torch.ones_like(cond) * model.num_classes
104
+ cond_combined = torch.cat([cond, cond_null])
105
+ else:
106
+ cond_combined = cond
107
+
108
+ cur_tokens = []
109
+ for i in range(max_new_tokens):
110
+ pos_embed = model.pos_embed[:, :i + 1].view(1, -1, model.hidden_size).expand(batch_size, -1, -1)
111
+ if cfg_scale > 1.0:
112
+ pos_embed = torch.cat([pos_embed, pos_embed], dim=0)
113
+
114
+ attn_mask = torch.ones(model.cls_token_num + i + 1, model.cls_token_num + i + 1, dtype=torch.bool).tril(diagonal=0).to(device)
115
+
116
+ context = torch.cat(cur_tokens, dim=1) if len(cur_tokens) > 0 else None
117
+ if cfg_scale > 1.0 and context is not None:
118
+ context = torch.cat([context, context], dim=0)
119
+
120
+ next_token = model.sample(cond_combined, pos_embed, attn_mask, context, cfg_scale)
121
+ if cfg_scale > 1.0:
122
+ next_token, _ = next_token.chunk(2, dim=0)
123
+ cur_tokens.append(next_token.clone())
124
+
125
+ seq = torch.cat(cur_tokens, dim=1)
126
+
127
+ return seq
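A small numeric illustration of the guidance schedules implemented by get_cfg above (values are exact for cfg=4.0 spread over 6 decoding steps; the import assumes the repository root is on PYTHONPATH):

from paintmind.stage2.generate import get_cfg

print([get_cfg(4.0, i, 6, "linear") for i in range(6)])
# [1.5, 2.0, 2.5, 3.0, 3.5, 4.0]   guidance ramps up token by token
print([get_cfg(4.0, i, 6, "inv_linear") for i in range(6)])
# [3.5, 3.0, 2.5, 2.0, 1.5, 1.0]   guidance decays toward 1.0
print(get_cfg(4.0, 3, 6, "constant"))
# 4.0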
paintmind/stage2/gpt.py ADDED
@@ -0,0 +1,451 @@
1
+ # Modified from:
2
+ # VQGAN: https://github.com/CompVis/taming-transformers/blob/master/taming/modules/transformer/mingpt.py
3
+ # DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
4
+ # nanoGPT: https://github.com/karpathy/nanoGPT/blob/master/model.py
5
+ # llama: https://github.com/facebookresearch/llama/blob/main/llama/model.py
6
+ # gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/model.py
7
+ # PixArt: https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
8
+ from dataclasses import dataclass
9
+ from typing import Optional, List, Union
10
+
11
+ import math
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.nn import functional as F
15
+
16
+ from paintmind.stage1.vision_transformers import DropPath
17
+ from paintmind.stage2.diffloss import DiffLoss
18
+
19
+ def find_multiple(n: int, k: int):
20
+ if n % k == 0:
21
+ return n
22
+ return n + k - (n % k)
23
+
24
+
25
+
26
+ #################################################################################
27
+ # Embedding Layers for Class Labels #
28
+ #################################################################################
29
+ class LabelEmbedder(nn.Module):
30
+ """
31
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
32
+ """
33
+ def __init__(self, num_classes, hidden_size, dropout_prob):
34
+ super().__init__()
35
+ use_cfg_embedding = dropout_prob > 0
36
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
37
+ self.num_classes = num_classes
38
+ self.dropout_prob = dropout_prob
39
+
40
+ def token_drop(self, labels, force_drop_ids=None):
41
+ """
42
+ Drops labels to enable classifier-free guidance.
43
+ """
44
+ if force_drop_ids is None:
45
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
46
+ else:
47
+ drop_ids = force_drop_ids == 1
48
+ labels = torch.where(drop_ids, self.num_classes, labels)
49
+ return labels
50
+
51
+ def forward(self, labels, train, force_drop_ids=None):
52
+ use_dropout = self.dropout_prob > 0
53
+ if (train and use_dropout) or (force_drop_ids is not None):
54
+ labels = self.token_drop(labels, force_drop_ids)
55
+ embeddings = self.embedding_table(labels).unsqueeze(1)
56
+ return embeddings
57
+
58
+
59
+ class MLP(nn.Module):
60
+ def __init__(self, in_features, hidden_features, out_features):
61
+ super().__init__()
62
+ out_features = out_features or in_features
63
+ hidden_features = hidden_features or in_features
64
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=False)
65
+ self.act = nn.GELU(approximate='tanh')
66
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
67
+
68
+ def forward(self, x):
69
+ x = self.fc1(x)
70
+ x = self.act(x)
71
+ x = self.fc2(x)
72
+ return x
73
+
74
+
75
+ #################################################################################
76
+ # GPT Model #
77
+ #################################################################################
78
+ class RMSNorm(torch.nn.Module):
79
+ def __init__(self, dim: int, eps: float = 1e-5):
80
+ super().__init__()
81
+ self.eps = eps
82
+ self.weight = nn.Parameter(torch.ones(dim))
83
+
84
+ def _norm(self, x):
85
+ return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
86
+
87
+ def forward(self, x):
88
+ output = self._norm(x.float()).type_as(x)
89
+ return output * self.weight
90
+
91
+
92
+ class FeedForward(nn.Module):
93
+ def __init__(
94
+ self,
95
+ dim: int,
96
+ multiple_of: int = 256,
97
+ ffn_dropout_p: float = 0.0,
98
+ ):
99
+ super().__init__()
100
+ hidden_dim = 4 * dim
101
+ hidden_dim = int(2 * hidden_dim / 3)
102
+ hidden_dim = find_multiple(hidden_dim, multiple_of)
103
+
104
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
105
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
106
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
107
+ self.ffn_dropout = nn.Dropout(ffn_dropout_p)
108
+
109
+ def forward(self, x):
110
+ return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
111
+
112
+
113
+ class KVCache(nn.Module):
114
+ def __init__(self, max_batch_size, max_seq_length, n_head, head_dim, dtype):
115
+ super().__init__()
116
+ cache_shape = (max_batch_size, n_head, max_seq_length, head_dim)
117
+ self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
118
+ self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))
119
+
120
+ def update(self, input_pos, k_val, v_val):
121
+ # input_pos: [S], k_val: [B, H, S, D]
122
+ assert input_pos.shape[0] == k_val.shape[2]
123
+ k_out = self.k_cache
124
+ v_out = self.v_cache
125
+ k_out[:, :, input_pos] = k_val
126
+ v_out[:, :, input_pos] = v_val
127
+
128
+ return k_out, v_out
129
+
130
+
131
+ class Attention(nn.Module):
132
+ def __init__(
133
+ self,
134
+ dim: int,
135
+ n_head: int,
136
+ attn_dropout_p: float = 0.0,
137
+ resid_dropout_p: float = 0.1,
138
+ ):
139
+ super().__init__()
140
+ assert dim % n_head == 0
141
+ self.dim = dim
142
+ self.head_dim = dim // n_head
143
+ self.n_head = n_head
144
+
145
+ # key, query, value projections for all heads, but in a batch
146
+ self.wqkv = nn.Linear(dim, dim * 3, bias=False)
147
+ self.wo = nn.Linear(dim, dim, bias=False)
148
+ self.kv_cache = None
149
+
150
+ # regularization
151
+ self.attn_dropout_p = attn_dropout_p
152
+ self.resid_dropout = nn.Dropout(resid_dropout_p)
153
+
154
+ def forward(
155
+ self, x: torch.Tensor,
156
+ input_pos: Optional[torch.Tensor] = None,
157
+ mask: Optional[torch.Tensor] = None
158
+ ):
159
+ bsz, seqlen, _ = x.shape
160
+ xq, xk, xv = self.wqkv(x).split([self.dim, self.dim, self.dim], dim=-1)
161
+
162
+ xq = xq.view(bsz, seqlen, self.n_head, self.head_dim)
163
+ xk = xk.view(bsz, seqlen, self.n_head, self.head_dim)
164
+ xv = xv.view(bsz, seqlen, self.n_head, self.head_dim)
165
+
166
+ xq, xk, xv = map(lambda x: x.transpose(1, 2), (xq, xk, xv))
167
+
168
+ if self.kv_cache is not None:
169
+ keys, values = self.kv_cache.update(input_pos, xk, xv)
170
+ else:
171
+ keys, values = xk, xv
172
+
173
+ output = F.scaled_dot_product_attention(
174
+ xq, keys, values,
175
+ attn_mask=mask,
176
+ is_causal=True if mask is None else False, # is_causal=False is for KV cache
177
+ dropout_p=self.attn_dropout_p if self.training else 0)
178
+
179
+ output = output.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
180
+
181
+ output = self.resid_dropout(self.wo(output))
182
+ return output
183
+
184
+
185
+ class TransformerBlock(nn.Module):
186
+ def __init__(
187
+ self,
188
+ dim: int,
189
+ n_head: int,
190
+ multiple_of: int = 256,
191
+ norm_eps: float = 1e-5,
192
+ attn_dropout_p: float = 0.0,
193
+ ffn_dropout_p: float = 0.1,
194
+ resid_dropout_p: float = 0.1,
195
+ drop_path: float = 0.0,
196
+ ):
197
+ super().__init__()
198
+ self.attention = Attention(
199
+ dim=dim,
200
+ n_head=n_head,
201
+ attn_dropout_p=attn_dropout_p,
202
+ resid_dropout_p=resid_dropout_p,
203
+ )
204
+ self.feed_forward = FeedForward(
205
+ dim=dim,
206
+ multiple_of=multiple_of,
207
+ ffn_dropout_p=ffn_dropout_p,
208
+ )
209
+ self.attention_norm = RMSNorm(dim, eps=norm_eps)
210
+ self.ffn_norm = RMSNorm(dim, eps=norm_eps)
211
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
212
+
213
+ def forward(self, x: torch.Tensor, start_pos: int, mask: Optional[torch.Tensor] = None):
214
+ h = x + self.drop_path(self.attention(self.attention_norm(x), start_pos, mask))
215
+ out = h + self.drop_path(self.feed_forward(self.ffn_norm(h)))
216
+ return out
217
+
218
+
219
+ class Transformer(nn.Module):
220
+ def __init__(
221
+ self,
222
+ dim: int = 4096,
223
+ n_layer: int = 32,
224
+ n_head: int = 32,
225
+ attn_dropout_p: float = 0.0,
226
+ resid_dropout_p: float = 0.1,
227
+ ffn_dropout_p: float = 0.1,
228
+ drop_path_rate: float = 0.0,
229
+ num_classes: Union[int, List[int]] = 1000,
230
+ class_dropout_prob: float = 0.1,
231
+
232
+ cls_token_num: int = 1,
233
+ num_slots: int = 16,
234
+ slot_dim: int = 256,
235
+
236
+ diffloss_d: int = 3,
237
+ diffloss_w: int = 1024,
238
+ num_sampling_steps: str = '100',
239
+ diffusion_batch_mul: int = 4,
240
+ predict_xstart: bool = False,
241
+ use_si: bool = False,
242
+ deep_supervision: bool = False,
243
+ token_drop_prob: float = 0.0,
244
+ cond_method: str = "adaln",
245
+ decoupled_cfg: bool = True,
246
+ **kwargs,
247
+ ):
248
+ super().__init__()
249
+
250
+ # Store configuration
251
+ self.dim = dim
252
+ self.n_layer = n_layer
253
+ self.n_head = n_head
254
+ self.num_slots = num_slots
255
+ self.slot_dim = slot_dim
256
+ self.num_classes = num_classes
257
+ self.cls_token_num = cls_token_num
258
+
259
+ # Initialize embeddings
260
+ self.cls_embedding = LabelEmbedder(num_classes, dim, class_dropout_prob)
261
+ self.z_proj = nn.Linear(slot_dim, dim, bias=True)
262
+ self.z_proj_ln = RMSNorm(dim)
263
+ self.pos_embed_learned = nn.Parameter(torch.zeros(1, num_slots + cls_token_num, dim))
264
+
265
+ # transformer blocks
266
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layer)]
267
+ self.layers = torch.nn.ModuleList()
268
+ for layer_id in range(n_layer):
269
+ self.layers.append(TransformerBlock(
270
+ dim=dim,
271
+ n_head=n_head,
272
+ ffn_dropout_p=ffn_dropout_p,
273
+ attn_dropout_p=attn_dropout_p,
274
+ resid_dropout_p=resid_dropout_p,
275
+ drop_path=dpr[layer_id],
276
+ ))
277
+
278
+ # output layer
279
+ self.norm = RMSNorm(dim)
280
+
281
+ self.diffusion_pos_embed_learned = nn.Parameter(torch.zeros(1, num_slots, dim))
282
+
283
+ # KVCache
284
+ self.max_batch_size = -1
285
+ self.max_seq_length = -1
286
+
287
+ self.initialize_weights()
288
+
289
+ # Diffusion Loss
290
+ self.diffloss = DiffLoss(
291
+ target_channels=slot_dim,
292
+ z_channels=dim,
293
+ width=diffloss_w,
294
+ depth=diffloss_d,
295
+ num_sampling_steps=num_sampling_steps,
296
+ predict_xstart=predict_xstart,
297
+ use_si=use_si,
298
+ deep_supervision=deep_supervision,
299
+ token_drop_prob=token_drop_prob,
300
+ cond_method=cond_method,
301
+ decoupled_cfg=decoupled_cfg,
302
+ )
303
+ self.decoupled_cfg = decoupled_cfg
304
+ self.diffusion_batch_mul = diffusion_batch_mul
305
+
306
+ def initialize_weights(self):
307
+ nn.init.normal_(self.pos_embed_learned, std=0.02)
308
+ nn.init.normal_(self.diffusion_pos_embed_learned, std=0.02)
309
+ # Initialize nn.Linear and nn.Embedding
310
+ self.apply(self._init_weights)
311
+
312
+ def _init_weights(self, module):
313
+ if isinstance(module, nn.Linear):
314
+ module.weight.data.normal_(std=0.02)
315
+ if module.bias is not None:
316
+ module.bias.data.zero_()
317
+ elif isinstance(module, nn.Embedding):
318
+ module.weight.data.normal_(std=0.02)
319
+
320
+ def setup_caches(self, max_batch_size, max_seq_length, dtype):
321
+ # if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
322
+ # return
323
+ head_dim = self.dim // self.n_head
324
+ max_seq_length = find_multiple(max_seq_length, 8)
325
+ self.max_seq_length = max_seq_length
326
+ self.max_batch_size = max_batch_size
327
+ for b in self.layers:
328
+ b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.n_head, head_dim, dtype)
329
+
330
+ causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
331
+ self.causal_mask = causal_mask.unsqueeze(0).repeat(self.max_batch_size, 1, 1)
332
+
333
+ def reset_caches(self):
334
+ self.max_seq_length = -1
335
+ self.max_batch_size = -1
336
+ for b in self.layers:
337
+ b.attention.kv_cache = None
338
+
339
+ def forward_loss(self, z, target):
340
+ bsz, seq_len, _ = target.shape
341
+ target = target.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
342
+ z = z.reshape(bsz*seq_len, -1).repeat(self.diffusion_batch_mul, 1)
343
+ loss = self.diffloss(z=z, target=target)
344
+ return loss
345
+
346
+ def forward_cfg(self, h, cfg):
347
+ if cfg > 1.0:
348
+ h_cond, h_uncond = h.chunk(2, dim=0)
349
+ h = h_uncond + cfg * (h_cond - h_uncond)
350
+ return h
351
+
352
+ def forward(
353
+ self,
354
+ slots: torch.Tensor,
355
+ cond_idx: torch.Tensor, # cond_idx_or_embed
356
+ input_pos: Optional[torch.Tensor] = None,
357
+ mask: Optional[torch.Tensor] = None,
358
+ cfg: float = 1.0,
359
+ diff_cfg: float = 1.0,
360
+ temperature: float = 1.0
361
+ ):
362
+ if slots is not None and cond_idx is not None: # training or naive inference
363
+ cond_embeddings = self.cls_embedding(cond_idx, train=self.training)
364
+ cond_embeddings = cond_embeddings.expand(-1, self.cls_token_num, -1)
365
+ token_embeddings = self.z_proj(slots)
366
+ token_embeddings = torch.cat((cond_embeddings, token_embeddings), dim=1)
367
+ else:
368
+ if cond_idx is not None: # prefill in inference
369
+ token_embeddings = self.cls_embedding(cond_idx, train=self.training)
370
+ token_embeddings = token_embeddings.expand(-1, self.cls_token_num, -1)
371
+ else: # decode_n_tokens(kv cache) in inference
372
+ token_embeddings = self.z_proj(slots)
373
+
374
+ bs = token_embeddings.shape[0]
375
+ mask = self.causal_mask[:bs, None, input_pos]
376
+
377
+ h = token_embeddings
378
+ if self.training:
379
+ h = h + self.pos_embed_learned
380
+ else:
381
+ h = h + self.pos_embed_learned[:, input_pos].view(1, -1, self.dim)
382
+
383
+ h = self.z_proj_ln(h) # not sure if this is needed
384
+
385
+ # transformer blocks
386
+ for layer in self.layers:
387
+ h = layer(h, input_pos, mask)
388
+
389
+ h = self.norm(h)
390
+
391
+ if self.training:
392
+ h = h[:, self.cls_token_num - 1 : -1].contiguous()
393
+ h = h + self.diffusion_pos_embed_learned
394
+ loss = self.forward_loss(h, slots.detach())
395
+ return loss
396
+ else:
397
+ if self.decoupled_cfg:
398
+ h = self.forward_cfg(h[:, -1], cfg)
399
+ h = h + self.diffusion_pos_embed_learned[:, input_pos[-1] - self.cls_token_num + 1]
400
+ if diff_cfg > 1.0 and hasattr(self.diffloss.net, 'null_token'):
401
+ null_token = self.diffloss.net.null_token.expand(h.shape[0], -1)
402
+ h = torch.cat((h, null_token), dim=0)
403
+ else:
404
+ diff_cfg = 1.0
405
+ next_tokens = self.diffloss.sample(h, temperature=temperature, cfg=diff_cfg)
406
+ else:
407
+ h = h[:, -1]
408
+ h = h + self.diffusion_pos_embed_learned[:, input_pos[-1] - self.cls_token_num + 1]
409
+ next_tokens = self.diffloss.sample(h, temperature=temperature, cfg=cfg)
410
+ return next_tokens
411
+
412
+
413
+ def get_fsdp_wrap_module_list(self) -> List[nn.Module]:
414
+ return list(self.layers)
415
+
416
+
417
+
418
+ #################################################################################
419
+ # GPT Configs #
420
+ #################################################################################
421
+ ### text-conditional
422
+ def GPT_7B(**kwargs):
423
+ return Transformer(n_layer=32, n_head=32, dim=4096, **kwargs) # 6.6B
424
+
425
+ def GPT_3B(**kwargs):
426
+ return Transformer(n_layer=24, n_head=32, dim=3200, **kwargs) # 3.1B
427
+
428
+ def GPT_1B(**kwargs):
429
+ return Transformer(n_layer=22, n_head=32, dim=2048, **kwargs) # 1.2B
430
+
431
+ ### class-conditional
432
+ def GPT_XXXL(**kwargs):
433
+ return Transformer(n_layer=48, n_head=40, dim=2560, **kwargs) # 3.9B
434
+
435
+ def GPT_XXL(**kwargs):
436
+ return Transformer(n_layer=48, n_head=24, dim=1536, **kwargs) # 1.4B
437
+
438
+ def GPT_XL(**kwargs):
439
+ return Transformer(n_layer=36, n_head=20, dim=1280, **kwargs) # 775M
440
+
441
+ def GPT_L(**kwargs):
442
+ return Transformer(n_layer=24, n_head=16, dim=1024, **kwargs) # 343M
443
+
444
+ def GPT_B(**kwargs):
445
+ return Transformer(n_layer=12, n_head=12, dim=768, **kwargs) # 111M
446
+
447
+
448
+ GPT_models = {
449
+ 'GPT-B': GPT_B, 'GPT-L': GPT_L, 'GPT-XL': GPT_XL, 'GPT-XXL': GPT_XXL, 'GPT-XXXL': GPT_XXXL,
450
+ 'GPT-1B': GPT_1B, 'GPT-3B': GPT_3B, 'GPT-7B': GPT_7B,
451
+ }
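A minimal construction sketch for the registry above (the slot and class settings are illustrative, not the training configuration used in this repo):

from paintmind.stage2.gpt import GPT_models

# class-conditional GPT-L with a 16-slot, 16-dim tokenizer latent space
model = GPT_models['GPT-L'](num_classes=1000, num_slots=16, slot_dim=16,
                            cls_token_num=1, diffloss_d=3, diffloss_w=1024)
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e6:.1f}M parameters (transformer plus DiffLoss head)")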
paintmind/utils/__init__.py ADDED
File without changes
paintmind/utils/datasets.py ADDED
@@ -0,0 +1,77 @@
1
+ import os
2
+ import torch
3
+ import torchvision
4
+ import numpy as np
5
+ import os.path as osp
6
+ from glob import glob
7
+ from PIL import Image
8
+ import torchvision
9
+ import torchvision.transforms as TF
10
+
11
+ def pair(t):
12
+ return t if isinstance(t, tuple) else (t, t)
13
+
14
+ def center_crop_arr(pil_image, image_size):
15
+ """
16
+ Center cropping implementation from ADM.
17
+ https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
18
+ """
19
+ while min(*pil_image.size) >= 2 * image_size:
20
+ pil_image = pil_image.resize(
21
+ tuple(x // 2 for x in pil_image.size), resample=Image.BOX
22
+ )
23
+
24
+ scale = image_size / min(*pil_image.size)
25
+ pil_image = pil_image.resize(
26
+ tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
27
+ )
28
+
29
+ arr = np.array(pil_image)
30
+ crop_y = (arr.shape[0] - image_size) // 2
31
+ crop_x = (arr.shape[1] - image_size) // 2
32
+ return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
33
+
34
+ def vae_transforms(image_set, aug='randcrop', img_size=256):
35
+
36
+ t = []
37
+ if image_set == 'train':
38
+ if aug == 'randcrop':
39
+ t.append(TF.Resize(img_size, interpolation=TF.InterpolationMode.BICUBIC, antialias=True))
40
+ t.append(TF.RandomCrop(img_size))
41
+ elif aug == 'centercrop':
42
+ t.append(TF.Lambda(lambda x: center_crop_arr(x, img_size)))
43
+ else:
44
+ raise ValueError(f"Invalid augmentation: {aug}")
45
+ t.append(TF.RandomHorizontalFlip(p=0.5))
46
+ else:
47
+ t.append(TF.Resize(img_size, interpolation=TF.InterpolationMode.BICUBIC, antialias=True))
48
+ t.append(TF.CenterCrop(img_size))
49
+
50
+ t.append(TF.ToTensor())
51
+
52
+ return TF.Compose(t)
53
+
54
+
55
+ def cached_transforms(aug='tencrop', img_size=256, crop_ranges=[1.05, 1.10]):
56
+ t = []
57
+ if 'centercrop' in aug:
58
+ t.append(TF.Lambda(lambda x: center_crop_arr(x, img_size)))
59
+ t.append(TF.Lambda(lambda x: torch.stack([TF.ToTensor()(x), TF.ToTensor()(TF.functional.hflip(x))])))
60
+ elif 'tencrop' in aug:
61
+ crop_sizes = [int(img_size * crop_range) for crop_range in crop_ranges]
62
+ t.append(TF.Lambda(lambda x: [center_crop_arr(x, crop_size) for crop_size in crop_sizes]))
63
+ t.append(TF.Lambda(lambda crops: [crop for crop_tuple in [TF.TenCrop(img_size)(crop) for crop in crops] for crop in crop_tuple]))
64
+ t.append(TF.Lambda(lambda crops: torch.stack([TF.ToTensor()(crop) for crop in crops])))
65
+ else:
66
+ raise ValueError(f"Invalid augmentation: {aug}")
67
+
68
+ return TF.Compose(t)
69
+
70
+
71
+ class ImageNet(torchvision.datasets.ImageFolder):
72
+ def __init__(self, root, split='train', aug='randcrop', img_size=256):
73
+ super().__init__(osp.join(root, split))
74
+ if not 'cache' in aug:
75
+ self.transform = vae_transforms(split, aug=aug, img_size=img_size)
76
+ else:
77
+ self.transform = cached_transforms(aug=aug, img_size=img_size)
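A short sketch of applying the transforms above to one of the example images shipped with this repo (assumes Pillow and torchvision are installed and the script is run from the repository root):

from PIL import Image
from paintmind.utils.datasets import vae_transforms, center_crop_arr

transform = vae_transforms('train', aug='centercrop', img_size=256)
img = Image.open('examples/city.jpg').convert('RGB')
x = transform(img)
print(x.shape)                            # torch.Size([3, 256, 256])
print(center_crop_arr(img, 256).size)     # (256, 256), ADM-style center crop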
paintmind/utils/device_utils.py ADDED
@@ -0,0 +1,20 @@
1
+ import gc
2
+ import torch
3
+ import importlib.util
4
+
5
+ def configure_compute_backend():
6
+ """Configure PyTorch compute backend settings for CUDA."""
7
+ if torch.cuda.is_available():
8
+ torch.backends.cuda.matmul.allow_tf32 = True
9
+ torch.backends.cudnn.allow_tf32 = True
10
+ torch.backends.cudnn.benchmark = True
11
+ torch.backends.cudnn.deterministic = False
12
+ else:
13
+ raise ValueError("No CUDA available")
14
+
15
+ def get_device():
16
+ """Get the device to use for training."""
17
+ if torch.cuda.is_available():
18
+ return torch.device("cuda")
19
+ else:
20
+ raise ValueError("No CUDA available")
paintmind/utils/logger.py ADDED
@@ -0,0 +1,170 @@
1
+ from collections import defaultdict, deque
2
+ import datetime
3
+ import time
4
+ import torch
5
+ import torch.distributed as dist
6
+ from paintmind.engine.misc import is_dist_avail_and_initialized, is_main_process
7
+ from paintmind.utils.device_utils import get_device
8
+
9
+ def synchronize_processes():
10
+ if torch.cuda.is_available():
11
+ torch.cuda.synchronize()
12
+ else: # do nothing
13
+ pass
14
+
15
+ def empty_cache():
16
+ if torch.cuda.is_available():
17
+ torch.cuda.empty_cache()
18
+ else: # do nothing
19
+ pass
20
+
21
+ class SmoothedValue(object):
22
+ """Track a series of values and provide access to smoothed values over a
23
+ window or the global series average.
24
+ """
25
+
26
+ def __init__(self, window_size=20, fmt=None):
27
+ if fmt is None:
28
+ fmt = "{median:.4f} ({global_avg:.4f})"
29
+ self.deque = deque(maxlen=window_size)
30
+ self.total = 0.0
31
+ self.count = 0
32
+ self.fmt = fmt
33
+
34
+ def update(self, value, n=1):
35
+ self.deque.append(value)
36
+ self.count += n
37
+ self.total += value * n
38
+
39
+ def synchronize_between_processes(self):
40
+ """
41
+ Warning: does not synchronize the deque!
42
+ """
43
+ if not is_dist_avail_and_initialized():
44
+ return
45
+ t = torch.tensor([self.count, self.total], dtype=torch.float32, device=get_device())
46
+ dist.barrier()
47
+ dist.all_reduce(t)
48
+ t = t.tolist()
49
+ self.count = int(t[0])
50
+ self.total = t[1]
51
+
52
+ @property
53
+ def median(self):
54
+ d = torch.tensor(list(self.deque))
55
+ return d.median().item()
56
+
57
+ @property
58
+ def avg(self):
59
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
60
+ return d.mean().item()
61
+
62
+ @property
63
+ def global_avg(self):
64
+ return self.total / self.count
65
+
66
+ @property
67
+ def max(self):
68
+ return max(self.deque)
69
+
70
+ @property
71
+ def value(self):
72
+ return self.deque[-1]
73
+
74
+ def __str__(self):
75
+ return self.fmt.format(
76
+ median=self.median,
77
+ avg=self.avg,
78
+ global_avg=self.global_avg,
79
+ max=self.max,
80
+ value=self.value)
81
+
82
+
83
+ class MetricLogger(object):
84
+ def __init__(self, delimiter="\t"):
85
+ self.meters = defaultdict(SmoothedValue)
86
+ self.delimiter = delimiter
87
+
88
+ def update(self, **kwargs):
89
+ for k, v in kwargs.items():
90
+ if v is None:
91
+ continue
92
+ if isinstance(v, torch.Tensor):
93
+ v = v.item()
94
+ assert isinstance(v, (float, int))
95
+ self.meters[k].update(v)
96
+
97
+ def __getattr__(self, attr):
98
+ if attr in self.meters:
99
+ return self.meters[attr]
100
+ if attr in self.__dict__:
101
+ return self.__dict__[attr]
102
+ raise AttributeError("'{}' object has no attribute '{}'".format(
103
+ type(self).__name__, attr))
104
+
105
+ def __str__(self):
106
+ loss_str = []
107
+ for name, meter in self.meters.items():
108
+ loss_str.append(
109
+ "{}: {}".format(name, str(meter))
110
+ )
111
+ return self.delimiter.join(loss_str)
112
+
113
+ def synchronize_between_processes(self):
114
+ for meter in self.meters.values():
115
+ meter.synchronize_between_processes()
116
+
117
+ def add_meter(self, name, meter):
118
+ self.meters[name] = meter
119
+
120
+ def log_every(self, iterable, print_freq, header=None):
121
+ i = 0
122
+ if not header:
123
+ header = ''
124
+ start_time = time.time()
125
+ end = time.time()
126
+ iter_time = SmoothedValue(fmt='{avg:.4f}')
127
+ data_time = SmoothedValue(fmt='{avg:.4f}')
128
+ space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
129
+ log_msg = [
130
+ header,
131
+ '[{0' + space_fmt + '}/{1}]',
132
+ 'eta: {eta}',
133
+ '{meters}',
134
+ 'time: {time}',
135
+ 'data: {data}'
136
+ ]
137
+ if torch.cuda.is_available():
138
+ log_msg.append('mem: {memory:.0f}')
139
+ log_msg.append("util: {util:.1f}%")
140
+ log_msg = self.delimiter.join(log_msg)
141
+ MB = 1024.0 * 1024.0
142
+ for obj in iterable:
143
+ data_time.update(time.time() - end)
144
+ yield obj
145
+ iter_time.update(time.time() - end)
146
+ if i % print_freq == 0 or i == len(iterable) - 1:
147
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
148
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
149
+ if torch.cuda.is_available():
150
+ if is_main_process():
151
+ memory = torch.cuda.max_memory_allocated()
152
+ util = torch.cuda.utilization()
153
+ print(log_msg.format(
154
+ i, len(iterable), eta=eta_string,
155
+ meters=str(self),
156
+ time=str(iter_time), data=str(data_time),
157
+ memory=memory / MB, util=util))
158
+ else:
159
+ if is_main_process():
160
+ print(log_msg.format(
161
+ i, len(iterable), eta=eta_string,
162
+ meters=str(self),
163
+ time=str(iter_time), data=str(data_time)))
164
+ i += 1
165
+ end = time.time()
166
+ total_time = time.time() - start_time
167
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
168
+ if is_main_process():
169
+ print('{} Total time: {} ({:.4f} s / it)'.format(
170
+ header, total_time_str, total_time / len(iterable)))
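A minimal single-process sketch of the logging utilities above (it stands outside any real training loop and assumes distributed mode is not initialized):

import time
from paintmind.utils.logger import MetricLogger, SmoothedValue

metric_logger = MetricLogger(delimiter="  ")
metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
for step in metric_logger.log_every(range(50), print_freq=10, header='Epoch: [0]'):
    time.sleep(0.01)                              # stand-in for a training step
    metric_logger.update(loss=1.0 / (step + 1), lr=1e-4)
print("Averaged stats:", metric_logger)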
paintmind/utils/lr_scheduler.py ADDED
@@ -0,0 +1,15 @@
1
+ from timm.scheduler.cosine_lr import CosineLRScheduler
2
+ from timm.scheduler.step_lr import StepLRScheduler
3
+
4
+ def build_scheduler(optimizer, n_epoch, n_iter_per_epoch, lr_min=0, warmup_steps=0, warmup_lr_init=0, decay_steps=None, cosine_lr=True):
5
+ if decay_steps is None:
6
+ decay_steps = n_epoch * n_iter_per_epoch
7
+
8
+ if cosine_lr:
9
+ scheduler = CosineLRScheduler(optimizer, t_initial=decay_steps, lr_min=lr_min, warmup_t=warmup_steps, warmup_lr_init=warmup_lr_init,
10
+ cycle_limit=1, t_in_epochs=False, warmup_prefix=True)
11
+ else:
12
+ scheduler = StepLRScheduler(optimizer, decay_t=decay_steps, warmup_t=warmup_steps, warmup_lr_init=warmup_lr_init,
13
+ t_in_epochs=False, warmup_prefix=True)
14
+
15
+ return scheduler
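A short sketch of wiring the scheduler above to an optimizer (step counts are illustrative; with t_in_epochs=False the timm schedulers are advanced once per iteration via step_update):

import torch
from paintmind.utils.lr_scheduler import build_scheduler

param = torch.nn.Parameter(torch.zeros(10))
optimizer = torch.optim.AdamW([param], lr=1e-4)
scheduler = build_scheduler(optimizer, n_epoch=100, n_iter_per_epoch=1000,
                            lr_min=1e-6, warmup_steps=5000, warmup_lr_init=1e-7)
for step in range(100 * 1000):
    # forward / backward / optimizer.step() would go here
    scheduler.step_update(step)    # per-iteration LR update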
paintmind/utils/transform.py ADDED
@@ -0,0 +1,35 @@
1
+ import PIL
2
+ import torchvision.transforms as T
3
+
4
+ def pair(t):
5
+ return t if isinstance(t, tuple) else (t, t)
6
+
7
+ def stage1_transform(img_size=256, is_train=True, scale=0.8):
8
+
9
+ resize = pair(int(img_size/scale))
10
+ t = []
11
+ t.append(T.Resize(resize, interpolation=PIL.Image.BICUBIC))
12
+ if is_train:
13
+ t.append(T.RandomCrop(img_size))
14
+ t.append(T.RandomHorizontalFlip(p=0.5))
15
+ else:
16
+ t.append(T.CenterCrop(img_size))
17
+
18
+ t.append(T.ToTensor())
19
+ t.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
20
+
21
+ return T.Compose(t)
22
+
23
+ def stage2_transform(img_size=256, is_train=True, scale=0.8):
24
+ resize = pair(int(img_size/scale))
25
+ t = []
26
+ t.append(T.Resize(resize, interpolation=PIL.Image.BICUBIC))
27
+ if is_train:
28
+ t.append(T.RandomCrop(img_size))
29
+ else:
30
+ t.append(T.CenterCrop(img_size))
31
+
32
+ t.append(T.ToTensor())
33
+ t.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
34
+
35
+ return T.Compose(t)
paintmind/version.py ADDED
@@ -0,0 +1 @@
1
+ __version__ = '0.0.0'
requirements.txt ADDED
@@ -0,0 +1,27 @@
1
+ numpy==1.26.4
2
+ sympy>=1.10
3
+ accelerate
4
+ datasets
5
+ diffusers[torch]
6
+ transformers
7
+ safetensors
8
+ smart_open
9
+ dotwiz
10
+ omegaconf
11
+ tensorboard
12
+ huggingface-hub
13
+ einops
14
+ lpips
15
+ timm
16
+ scipy
17
+ scikit-learn
18
+ scikit-image
19
+ kornia
20
+ torchtyping
21
+ git+https://github.com/xwen99/torch-fidelity.git@master#egg=torch-fidelity
22
+ open_clip_torch
23
+ opencv-python-headless
24
+ torchmetrics
25
+ torchdiffeq
26
+ lmdb
27
+ triton==3.0.0
submitit_test.py ADDED
@@ -0,0 +1,290 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # A script to run multinode training with submitit.
8
+ # --------------------------------------------------------
9
+
10
+ import argparse
11
+ import os.path as osp
12
+ import submitit
13
+ import itertools
14
+
15
+ from omegaconf import OmegaConf
16
+ from paintmind.engine.util import instantiate_from_config
17
+ from paintmind.utils.device_utils import configure_compute_backend
18
+
19
+
20
+ def parse_args():
21
+ parser = argparse.ArgumentParser("Submitit for accelerator training")
22
+ # Slurm configuration
23
+ parser.add_argument("--ngpus", default=8, type=int, help="Number of gpus to request on each node")
24
+ parser.add_argument("--nodes", default=1, type=int, help="Number of nodes to request")
25
+ parser.add_argument("--timeout", default=7000, type=int, help="Duration of the job, default 5 days")
26
+ parser.add_argument("--qos", default="normal", type=str, help="QOS to request")
27
+ parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.")
28
+ parser.add_argument("--partition", default="h100-camera-train", type=str, help="Partition where to submit")
29
+ parser.add_argument("--exclude", default="", type=str, help="Exclude nodes from the partition")
30
+ parser.add_argument("--nodelist", default="", type=str, help="Nodelist to request")
31
+ parser.add_argument('--comment', default="", type=str, help="Comment to pass to scheduler")
32
+
33
+ # Model and testing configuration
34
+ parser.add_argument('--model', type=str, nargs='+', default=[None], help="Path to model(s)")
35
+ parser.add_argument('--step', type=int, nargs='+', default=[250000], help="Step number(s)")
36
+ parser.add_argument('--cfg', type=str, default=None, help="Path to config file")
37
+ parser.add_argument('--dataset', type=str, default='imagenet', help="Dataset to use")
38
+
39
+ # Legacy parameter (preserved for backward compatibility)
40
+ parser.add_argument('--cfg_value', type=float, nargs='+', default=[None],
41
+ help='Legacy parameter for GPT classifier-free guidance scale')
42
+
43
+ # CFG-related parameters - all with nargs='+' to support multiple values
44
+ parser.add_argument('--ae_cfg', type=float, nargs='+', default=[None],
45
+ help="Autoencoder classifier-free guidance scale")
46
+ parser.add_argument('--diff_cfg', type=float, nargs='+', default=[None],
47
+ help="Diffusion classifier-free guidance scale")
48
+ parser.add_argument('--cfg_schedule', type=str, nargs='+', default=[None],
49
+ help="CFG schedule type (e.g., constant, linear)")
50
+ parser.add_argument('--diff_cfg_schedule', type=str, nargs='+', default=[None],
51
+ help="Diffusion CFG schedule type (e.g., constant, inv_linear)")
52
+ parser.add_argument('--test_num_slots', type=int, nargs='+', default=[None],
53
+ help="Number of slots to use for inference")
54
+ parser.add_argument('--temperature', type=float, nargs='+', default=[None],
55
+ help="Temperature for sampling")
56
+
57
+ return parser.parse_args()
58
+
59
+
60
+ def load_config(model_path, cfg_path=None):
61
+ """Load configuration from file or model directory."""
62
+ if cfg_path is not None and osp.exists(cfg_path):
63
+ config_path = cfg_path
64
+ elif model_path and osp.exists(osp.join(model_path, 'config.yaml')):
65
+ config_path = osp.join(model_path, 'config.yaml')
66
+ else:
67
+ raise ValueError(f"No config file found at {model_path} or {cfg_path}")
68
+
69
+ return OmegaConf.load(config_path)
70
+
71
+
72
+ def setup_checkpoint_path(model_path, step, config):
73
+ """Set up the checkpoint path based on model and step."""
74
+ if model_path:
75
+ ckpt_path = osp.join(model_path, 'models', f'step{step}')
76
+ if not osp.exists(ckpt_path):
77
+ print(f"Skipping non-existent checkpoint: {ckpt_path}")
78
+ return None
79
+ if hasattr(config.trainer.params, 'model'):
80
+ config.trainer.params.model.params.ckpt_path = ckpt_path
81
+ else:
82
+ config.trainer.params.gpt_model.params.ckpt_path = ckpt_path
83
+ else:
84
+ result_folder = config.trainer.params.result_folder
85
+ ckpt_path = osp.join(result_folder, 'models', f'step{step}')
86
+ if hasattr(config.trainer.params, 'model'):
87
+ config.trainer.params.model.params.ckpt_path = ckpt_path
88
+ else:
89
+ config.trainer.params.gpt_model.params.ckpt_path = ckpt_path
90
+
91
+ return ckpt_path
92
+
93
+
94
+ def setup_test_config(config, use_coco=False):
95
+ """Set up common test configuration parameters."""
96
+ config.trainer.params.test_dataset = config.trainer.params.dataset
97
+ if not use_coco:
98
+ config.trainer.params.test_dataset.params.split = 'val'
99
+ else:
100
+ config.trainer.params.test_dataset.target = 'paintmind.utils.datasets.COCO'
101
+ config.trainer.params.test_dataset.params.root = './dataset/coco'
102
+ config.trainer.params.test_dataset.params.split = 'val2017'
103
+ config.trainer.params.test_only = True
104
+ config.trainer.params.compile = False
105
+ config.trainer.params.eval_fid = True
106
+ config.trainer.params.fid_stats = 'fid_stats/adm_in256_stats.npz'
107
+ if hasattr(config.trainer.params, 'model'):
108
+ config.trainer.params.model.params.num_sampling_steps = '250'
109
+ else:
110
+ config.trainer.params.ae_model.params.num_sampling_steps = '250'
111
+
112
+ def apply_cfg_params(config, param_dict):
113
+ """Apply CFG-related parameters to the config."""
114
+ # Apply each parameter if it's not None
115
+ if param_dict.get('cfg_value') is not None:
116
+ config.trainer.params.cfg = param_dict['cfg_value']
117
+ print(f"Setting cfg to {param_dict['cfg_value']}")
118
+
119
+ if param_dict.get('ae_cfg') is not None:
120
+ config.trainer.params.ae_cfg = param_dict['ae_cfg']
121
+ print(f"Setting ae_cfg to {param_dict['ae_cfg']}")
122
+
123
+ if param_dict.get('diff_cfg') is not None:
124
+ config.trainer.params.diff_cfg = param_dict['diff_cfg']
125
+ print(f"Setting diff_cfg to {param_dict['diff_cfg']}")
126
+
127
+ if param_dict.get('cfg_schedule') is not None:
128
+ config.trainer.params.cfg_schedule = param_dict['cfg_schedule']
129
+ print(f"Setting cfg_schedule to {param_dict['cfg_schedule']}")
130
+
131
+ if param_dict.get('diff_cfg_schedule') is not None:
132
+ config.trainer.params.diff_cfg_schedule = param_dict['diff_cfg_schedule']
133
+ print(f"Setting diff_cfg_schedule to {param_dict['diff_cfg_schedule']}")
134
+
135
+ if param_dict.get('test_num_slots') is not None:
136
+ config.trainer.params.test_num_slots = param_dict['test_num_slots']
137
+ print(f"Setting test_num_slots to {param_dict['test_num_slots']}")
138
+
139
+ if param_dict.get('temperature') is not None:
140
+ config.trainer.params.temperature = param_dict['temperature']
141
+ print(f"Setting temperature to {param_dict['temperature']}")
142
+
143
+
144
+ def run_test(config):
145
+ """Instantiate trainer and run test."""
146
+ trainer = instantiate_from_config(config.trainer)
147
+ trainer.train()
148
+
149
+
150
+ def generate_param_combinations(args):
151
+ """Generate all combinations of parameters from the provided arguments."""
152
+ # Create parameter grid for all combinations
153
+ param_grid = {
154
+ 'cfg_value': [None] if args.cfg_value == [None] else args.cfg_value,
155
+ 'ae_cfg': [None] if args.ae_cfg == [None] else args.ae_cfg,
156
+ 'diff_cfg': [None] if args.diff_cfg == [None] else args.diff_cfg,
157
+ 'cfg_schedule': [None] if args.cfg_schedule == [None] else args.cfg_schedule,
158
+ 'diff_cfg_schedule': [None] if args.diff_cfg_schedule == [None] else args.diff_cfg_schedule,
159
+ 'test_num_slots': [None] if args.test_num_slots == [None] else args.test_num_slots,
160
+ 'temperature': [None] if args.temperature == [None] else args.temperature
161
+ }
162
+
163
+ # Get all parameter names that have non-None values
164
+ active_params = [k for k, v in param_grid.items() if v != [None]]
165
+
166
+ if not active_params:
167
+ # If no parameters are specified, yield a dict with all None values
168
+ yield {k: None for k in param_grid.keys()}
169
+ return
170
+
171
+ # Generate all combinations of active parameters
172
+ active_values = [param_grid[k] for k in active_params]
173
+ for combination in itertools.product(*active_values):
174
+ param_dict = {k: None for k in param_grid.keys()} # Start with all None
175
+ for i, param_name in enumerate(active_params):
176
+ param_dict[param_name] = combination[i]
177
+ yield param_dict
178
+
179
+
180
+ class Trainer(object):
+     def __init__(self, args):
+         self.args = args
+
+     def __call__(self):
+         """Main entry point for the submitit job."""
+         self._setup_gpu_args()
+         configure_compute_backend()
+         self._run_tests()
+
+     def _run_tests(self):
+         """Run tests for all specified models and steps."""
+         for step in self.args.step:
+             for model in self.args.model:
+                 print(f"Testing model: {model} at step: {step}")
+
+                 # Load configuration
+                 config = load_config(model, self.args.cfg)
+
+                 # Setup checkpoint path
+                 ckpt_path = setup_checkpoint_path(model, step, config)
+                 if ckpt_path is None:
+                     continue
+
+                 use_coco = self.args.dataset in ('coco', 'COCO')
+                 # Setup test configuration
+                 setup_test_config(config, use_coco)
+
+                 # Generate and apply all parameter combinations
+                 for param_dict in generate_param_combinations(self.args):
+                     # Create a copy of the config for each parameter combination
+                     current_config = OmegaConf.create(OmegaConf.to_container(config, resolve=True))
+
+                     # Print parameter combination
+                     param_str = ", ".join([f"{k}={v}" for k, v in param_dict.items() if v is not None])
+                     print(f"Testing with parameters: {param_str}")
+
+                     # Apply parameters and run test
+                     apply_cfg_params(current_config, param_dict)
+                     run_test(current_config)
+
+     def _setup_gpu_args(self):
+         """Set up GPU and distributed environment variables."""
+         import submitit
+
+         print("Exporting PyTorch distributed environment variables")
+         dist_env = submitit.helpers.TorchDistributedEnvironment().export(set_cuda_visible_devices=False)
+         print(f"Master: {dist_env.master_addr}:{dist_env.master_port}")
+         print(f"Rank: {dist_env.rank}")
+         print(f"World size: {dist_env.world_size}")
+         print(f"Local rank: {dist_env.local_rank}")
+         print(f"Local world size: {dist_env.local_world_size}")
+
+         job_env = submitit.JobEnvironment()
+         self.args.output_dir = str(self.args.output_dir).replace("%j", str(job_env.job_id))
+         self.args.log_dir = self.args.output_dir
+         print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
+
+
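The per-combination copy in _run_tests matters: round-tripping the config through a plain container gives an independent OmegaConf object, so one run's overrides cannot leak into the next. A minimal sketch with a made-up config (illustration only):

from omegaconf import OmegaConf

base = OmegaConf.create({'trainer': {'params': {'cfg': 1.0, 'temperature': None}}})
run_cfg = OmegaConf.create(OmegaConf.to_container(base, resolve=True))  # independent copy
run_cfg.trainer.params.temperature = 0.9
print(base.trainer.params.temperature)     # None  (base config untouched)
print(run_cfg.trainer.params.temperature)  # 0.9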
+ def main():
+     """Main function to set up and submit the job."""
+     args = parse_args()
+
+     # Determine job directory
+     if args.cfg is not None and osp.exists(args.cfg):
+         config = OmegaConf.load(args.cfg)
+     elif osp.exists(osp.join(args.model[0], 'config.yaml')):
+         config = OmegaConf.load(osp.join(args.model[0], 'config.yaml'))
+     else:
+         raise ValueError(f"No config file found at {args.model[0]} or {args.cfg}")
+
+     args.job_dir = config.trainer.params.result_folder
+
+     # Set up the executor
+     executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30)
+
+     # Configure slurm parameters
+     slurm_kwargs = {
+         'slurm_signal_delay_s': 120,
+         'slurm_qos': args.qos
+     }
+
+     if args.comment:
+         slurm_kwargs['slurm_comment'] = args.comment
+     if args.exclude:
+         slurm_kwargs['slurm_exclude'] = args.exclude
+     if args.nodelist:
+         slurm_kwargs['slurm_nodelist'] = args.nodelist
+
+     # Update executor parameters
+     executor.update_parameters(
+         gpus_per_node=args.ngpus,
+         tasks_per_node=args.ngpus,  # one task per GPU
+         nodes=args.nodes,
+         timeout_min=args.timeout,
+         slurm_partition=args.partition,
+         name="fid",
+         **slurm_kwargs
+     )
+
+     args.output_dir = args.job_dir
+
+     # Submit the job
+     trainer = Trainer(args)
+     job = executor.submit(trainer)
+
+     print("Submitted job_id:", job.job_id)
+
+
+ if __name__ == "__main__":
+     main()
submitit_train.py ADDED
@@ -0,0 +1,148 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+ # --------------------------------------------------------
+ # A script to run multinode training with submitit.
+ # --------------------------------------------------------
+
+ import argparse
+ import os
+ import submitit
+
+ from omegaconf import OmegaConf
+ from paintmind.engine.util import instantiate_from_config
+ from paintmind.utils.device_utils import configure_compute_backend
+
+ def parse_args():
+     parser = argparse.ArgumentParser("Submitit for accelerator training")
+     parser.add_argument("--ngpus", default=8, type=int, help="Number of GPUs to request on each node")
+     parser.add_argument("--nodes", default=2, type=int, help="Number of nodes to request")
+     parser.add_argument("--timeout", default=7000, type=int, help="Duration of the job in minutes (7000 is roughly 5 days)")
+     parser.add_argument("--qos", default="normal", type=str, help="QOS to request")
+     parser.add_argument("--job_dir", default="", type=str, help="Job dir. Leave empty for automatic.")
+
+     parser.add_argument("--partition", default="h100-camera-train", type=str, help="Partition where to submit")
+     parser.add_argument("--exclude", default="", type=str, help="Exclude nodes from the partition")
+     parser.add_argument("--nodelist", default="", type=str, help="Nodelist to request")
+     parser.add_argument('--comment', default="", type=str, help="Comment to pass to the scheduler")
+     parser.add_argument('--cfg', type=str, default='configs/dit_imagenet_400ep.yaml', help='accelerator configs')
+     return parser.parse_args()
+
+
+ class Trainer(object):
+     def __init__(self, args, config):
+         self.args = args
+         self.config = config
+
+     def __call__(self):
+         self._setup_gpu_args()
+         configure_compute_backend()
+         trainer = instantiate_from_config(self.config.trainer)
+         trainer.train(self.config)
+
+     def checkpoint(self):
+         """Submitit requeue hook: point the config at the latest saved step, then resubmit."""
+         import os
+         import submitit
+
+         model_dir = os.path.join(self.args.output_dir, "models")
+         if os.path.exists(model_dir):
+             # Get all step folders
+             step_folders = [d for d in os.listdir(model_dir) if d.startswith("step")]
+             if step_folders:
+                 # Extract step numbers and find the latest one
+                 steps = [int(f.replace("step", "")) for f in step_folders]
+                 max_step = max(steps)
+                 # Set ckpt path to the latest step folder
+                 self.config.trainer.params.model.params.ckpt_path = os.path.join(model_dir, f"step{max_step}")
+         print("Requeuing ", self.args, self.config)
+         empty_trainer = type(self)(self.args, self.config)
+         return submitit.helpers.DelayedSubmission(empty_trainer)
+
+     def _setup_gpu_args(self):
+         import submitit
+
+         # print_env()
+         print("exporting PyTorch distributed environment variables")
+         dist_env = submitit.helpers.TorchDistributedEnvironment().export(set_cuda_visible_devices=False)
+         print(f"master: {dist_env.master_addr}:{dist_env.master_port}")
+         print(f"rank: {dist_env.rank}")
+         print(f"world size: {dist_env.world_size}")
+         print(f"local rank: {dist_env.local_rank}")
+         print(f"local world size: {dist_env.local_world_size}")
+         # print_env()
+
+         # os.environ["NCCL_DEBUG"] = "INFO"
+         os.environ["NCCL_P2P_DISABLE"] = "0"
+         os.environ["NCCL_IB_DISABLE"] = "0"
+
+         job_env = submitit.JobEnvironment()
+         self.args.output_dir = str(self.args.output_dir).replace("%j", str(job_env.job_id))
+         self.args.log_dir = self.args.output_dir
+         self.config.trainer.params.result_folder = self.args.output_dir
+         self.config.trainer.params.log_dir = os.path.join(self.args.output_dir, "logs")
+         # self.args.gpu = job_env.local_rank
+         # self.args.rank = job_env.global_rank
+         # self.args.world_size = job_env.num_tasks
+         print(f"Process group: {job_env.num_tasks} tasks, rank: {job_env.global_rank}")
+
+
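The resume logic in checkpoint() above picks the numerically largest stepN folder; a minimal sketch with made-up folder names shows why the int() conversion matters, since string ordering would resume from the wrong checkpoint:

step_folders = ["step2500", "step10000", "step20000"]  # hypothetical contents of <output_dir>/models
latest = max(int(f.replace("step", "")) for f in step_folders)
print(f"step{latest}")    # step20000
print(max(step_folders))  # step2500 -- lexicographic max picks the wrong folder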
+ def main():
+     args = parse_args()
+     cfg_file = args.cfg
+     assert os.path.exists(cfg_file), f"Config file not found: {cfg_file}"
+     config = OmegaConf.load(cfg_file)
+
+     if config.trainer.params.result_folder is None:
+         if args.job_dir == "":
+             args.job_dir = "./output/%j"
+
+         config.trainer.params.result_folder = args.job_dir
+         config.trainer.params.log_dir = os.path.join(args.job_dir, "logs")
+     else:
+         args.job_dir = config.trainer.params.result_folder
+
+     # Note that the folder depends on the job_id, which makes experiments easy to track
+     executor = submitit.AutoExecutor(folder=args.job_dir, slurm_max_num_timeout=30)
+
+     num_gpus_per_node = args.ngpus
+     nodes = args.nodes
+     timeout_min = args.timeout
+     qos = args.qos
+
+     partition = args.partition
+     kwargs = {}
+     if args.comment:
+         kwargs['slurm_comment'] = args.comment
+     if args.exclude:
+         kwargs["slurm_exclude"] = args.exclude
+     if args.nodelist:
+         kwargs["slurm_nodelist"] = args.nodelist
+
+     executor.update_parameters(
+         mem_gb=40 * num_gpus_per_node,
+         gpus_per_node=num_gpus_per_node,
+         tasks_per_node=num_gpus_per_node,  # one task per GPU
+         # cpus_per_task=16,
+         nodes=nodes,
+         timeout_min=timeout_min,  # max is 60 * 72
+         # Below are cluster dependent parameters
+         slurm_partition=partition,
+         slurm_signal_delay_s=120,
+         slurm_qos=qos,
+         **kwargs
+     )
+
+     executor.update_parameters(name="sar")
+
+     args.output_dir = args.job_dir
+
+     trainer = Trainer(args, config)
+     job = executor.submit(trainer)
+
+     print("Submitted job_id:", job.job_id)
+
+
+ if __name__ == "__main__":
+     main()
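Both scripts follow the same submitit pattern: build an AutoExecutor, describe resources with update_parameters, and submit a picklable callable whose __call__ runs on every task of the allocation. A minimal, cluster-agnostic sketch (folder and resource values are placeholders, not the defaults used above):

import submitit

def demo(name):
    return f"hello {name}"

executor = submitit.AutoExecutor(folder="submitit_logs/%j")  # %j is replaced by the job id
executor.update_parameters(timeout_min=10, nodes=1, tasks_per_node=1, name="demo")
job = executor.submit(demo, "world")
print("Submitted job_id:", job.job_id)
# job.result() blocks until the job finishes and returns "hello world"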