tennant committed on
Commit 7b0a1ef · 1 Parent(s): 448141b
Files changed (37):
  1. README.md +1 -1
  2. configs/autoregressive_l.yaml +68 -0
  3. configs/autoregressive_xl.yaml +66 -0
  4. configs/onenode_config.yaml +11 -0
  5. configs/tokenizer_l.yaml +55 -0
  6. configs/tokenizer_xl.yaml +55 -0
  7. examples/city.jpg +0 -0
  8. examples/food.jpg +0 -0
  9. examples/highland.webp +0 -0
  10. gen_demo.py +262 -0
  11. requirements.txt +16 -0
  12. semanticist/engine/diffusion_trainer.py +488 -0
  13. semanticist/engine/gpt_trainer.py +694 -0
  14. semanticist/engine/trainer_utils.py +251 -0
  15. semanticist/stage1/diffuse_slot.py +452 -0
  16. semanticist/stage1/diffusion/__init__.py +46 -0
  17. semanticist/stage1/diffusion/diffusion_utils.py +88 -0
  18. semanticist/stage1/diffusion/gaussian_diffusion.py +886 -0
  19. semanticist/stage1/diffusion/respace.py +130 -0
  20. semanticist/stage1/diffusion/timestep_sampler.py +150 -0
  21. semanticist/stage1/diffusion_transfomer.py +372 -0
  22. semanticist/stage1/fused_attention.py +45 -0
  23. semanticist/stage1/pos_embed.py +102 -0
  24. semanticist/stage1/transport/__init__.py +63 -0
  25. semanticist/stage1/transport/integrators.py +130 -0
  26. semanticist/stage1/transport/path.py +192 -0
  27. semanticist/stage1/transport/transport.py +456 -0
  28. semanticist/stage1/transport/utils.py +29 -0
  29. semanticist/stage1/vision_transformer.py +259 -0
  30. semanticist/stage2/diffloss.py +267 -0
  31. semanticist/stage2/generate.py +88 -0
  32. semanticist/stage2/gpt.py +431 -0
  33. semanticist/utils/datasets.py +72 -0
  34. semanticist/utils/device_utils.py +18 -0
  35. semanticist/utils/logger.py +170 -0
  36. semanticist/utils/lr_scheduler.py +15 -0
  37. semanticist/utils/transform.py +35 -0
README.md CHANGED
@@ -5,7 +5,7 @@ colorFrom: purple
  colorTo: blue
  sdk: gradio
  sdk_version: 5.20.1
- app_file: app.py
+ app_file: gen_demo.py
  pinned: false
  license: mit
  ---
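
(`app_file` is the entry script a Hugging Face Space executes on startup, so this change points the Space at the new Gradio demo, `gen_demo.py`, added below.)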
configs/autoregressive_l.yaml ADDED
@@ -0,0 +1,68 @@
+ trainer:
+   target: semanticist.engine.gpt_trainer.GPTTrainer
+   params:
+     num_epoch: 400
+     blr: 1e-4
+     cosine_lr: False
+     warmup_epochs: 100
+     batch_size: 16
+     num_workers: 8
+     pin_memory: True
+     grad_accum_steps: 1
+     precision: 'bf16'
+     max_grad_norm: 1.0
+     enable_ema: True
+     save_every: 10000
+     sample_every: 5000
+     fid_every: 50000
+     eval_fid: False
+     result_folder: "./output/autoregressive"
+     log_dir: "./output/autoregressive/logs"
+     ae_cfg: 1.0
+     cfg: 6.0
+     cfg_schedule: "linear"
+     train_num_slots: 32
+     test_num_slots: 32
+     compile: True
+     enable_cache_latents: False
+     ae_model:
+       target: semanticist.stage1.diffuse_slot.DiffuseSlot
+       params:
+         encoder: 'vit_base_patch16'
+         enc_img_size: 256
+         enc_causal: True
+         num_slots: 256
+         slot_dim: 16
+         norm_slots: True
+         cond_method: 'token'
+         dit_model: 'DiT-L-2'
+         vae: 'xwen99/mar-vae-kl16'
+         num_sampling_steps: '250'
+         # ckpt_path: ./output/tokenizer/models_l/step250000/custom_checkpoint_1.pkl
+         ckpt_path: /mnt/ceph_rbd/mnt_pvc_vid_data/zbc/cache/semanticist_tok_L.pkl
+
+     gpt_model:
+       target: GPT-L
+       params:
+         num_slots: 32
+         slot_dim: 16
+         num_classes: 1000
+         cls_token_num: 1
+         resid_dropout_p: 0.1
+         ffn_dropout_p: 0.1
+         diffloss_d: 12
+         diffloss_w: 1536
+         num_sampling_steps: '100'
+         diffusion_batch_mul: 4
+         use_si: True
+         cond_method: 'concat'
+         ckpt_path: None
+
+     dataset:
+       target: semanticist.utils.datasets.ImageNet
+       params:
+         root: ./dataset/imagenet/
+         split: train
+         # aug: tencrop_cached # or centercrop_cached
+         aug: randcrop
+         img_size: 256
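
For reference, this config is consumed the way `gen_demo.py` (added below) does it: OmegaConf loads the YAML, `instantiate_from_config` resolves each `target`/`params` pair, and the GPT variant is looked up in the `GPT_models` registry. A minimal sketch (checkpoint loading and `.to(device)` omitted):

from omegaconf import OmegaConf
from semanticist.engine.trainer_utils import instantiate_from_config
from semanticist.stage2.gpt import GPT_models

cfg = OmegaConf.load("configs/autoregressive_l.yaml")
params = cfg.trainer.params

ae_model = instantiate_from_config(params.ae_model)  # stage-1 DiffuseSlot tokenizer
gpt_model = GPT_models[params.gpt_model.target](**params.gpt_model.params)  # 'GPT-L'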
configs/autoregressive_xl.yaml ADDED
@@ -0,0 +1,66 @@
+ trainer:
+   target: semanticist.engine.gpt_trainer.GPTTrainer
+   params:
+     num_epoch: 400
+     blr: 1e-4
+     cosine_lr: False
+     warmup_epochs: 100
+     batch_size: 256
+     num_workers: 8
+     pin_memory: True
+     grad_accum_steps: 1
+     precision: 'bf16'
+     max_grad_norm: 1.0
+     enable_ema: True
+     save_every: 10000
+     sample_every: 5000
+     fid_every: 50000
+     eval_fid: False
+     result_folder: "./output/autoregressive"
+     log_dir: "./output/autoregressive/logs"
+     ae_cfg: 1.0
+     cfg: 5.0
+     cfg_schedule: "linear"
+     train_num_slots: 32
+     test_num_slots: 32
+     compile: True
+     enable_cache_latents: True
+     ae_model:
+       target: semanticist.stage1.diffuse_slot.DiffuseSlot
+       params:
+         encoder: 'vit_base_patch16'
+         enc_img_size: 256
+         enc_causal: True
+         num_slots: 256
+         slot_dim: 16
+         norm_slots: True
+         cond_method: 'token'
+         dit_model: 'DiT-XL-2'
+         vae: 'xwen99/mar-vae-kl16'
+         num_sampling_steps: '250'
+         ckpt_path: ./output/tokenizer/models_xl/step250000/custom_checkpoint_1.pkl
+
+     gpt_model:
+       target: GPT-L
+       params:
+         num_slots: 32
+         slot_dim: 16
+         num_classes: 1000
+         cls_token_num: 1
+         resid_dropout_p: 0.1
+         ffn_dropout_p: 0.1
+         diffloss_d: 12
+         diffloss_w: 1536
+         num_sampling_steps: '100'
+         diffusion_batch_mul: 4
+         use_si: True
+         cond_method: 'concat'
+         ckpt_path: None
+
+     dataset:
+       target: semanticist.utils.datasets.ImageNet
+       params:
+         root: ./dataset/imagenet/
+         split: train
+         aug: tencrop_cached # or centercrop_cached
+         img_size: 256
configs/onenode_config.yaml ADDED
@@ -0,0 +1,11 @@
+ compute_environment: LOCAL_MACHINE
+ deepspeed_config: {}
+ distributed_type: MULTI_GPU
+ fsdp_config: {}
+ machine_rank: 0
+ main_process_ip: null
+ main_process_port: null
+ main_training_function: main
+ num_machines: 1
+ num_processes: 1
+ use_cpu: false
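
This is a standard Hugging Face Accelerate launcher config (one machine, one process, `MULTI_GPU` layout). No training entry script is part of this commit, so the launch command in the comment below is only illustrative; at runtime the trainers added in this commit pick these settings up through `Accelerator()`:

# Hypothetical launch command (the script name is assumed, not part of this commit):
#   accelerate launch --config_file configs/onenode_config.yaml train.py --cfg configs/tokenizer_l.yaml
from accelerate import Accelerator

accelerator = Accelerator()  # inherits the distributed setup prepared by `accelerate launch`
print(accelerator.num_processes, accelerator.process_index)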
configs/tokenizer_l.yaml ADDED
@@ -0,0 +1,55 @@
+ trainer:
+   target: semanticist.engine.diffusion_trainer.DiffusionTrainer
+   params:
+     num_epoch: 400
+     valid_size: 64
+     blr: 2.5e-5
+     cosine_lr: True
+     warmup_epochs: 1
+     batch_size: 64
+     num_workers: 16
+     pin_memory: True
+     grad_accum_steps: 1
+     precision: 'bf16'
+     max_grad_norm: 3.0
+     enable_ema: True
+     save_every: 10000
+     sample_every: 5000
+     fid_every: 50000
+     result_folder: "./output/tokenizer/models_l"
+     log_dir: "./output/tokenizer/models_l/logs"
+     cfg: 3.0
+     compile: True
+     model:
+       target: semanticist.stage1.diffuse_slot.DiffuseSlot
+       params:
+         encoder: 'vit_base_patch16'
+         enc_img_size: 256
+         enc_causal: True
+         enc_use_mlp: False
+         num_slots: 256
+         slot_dim: 16
+         norm_slots: True
+         dit_model: 'DiT-L-2'
+         vae: 'xwen99/mar-vae-kl16'
+         enable_nest: False
+         enable_nest_after: 50
+         use_repa: True
+         eval_fid: True
+         fid_stats: 'fid_stats/adm_in256_stats.npz'
+         num_sampling_steps: '250'
+         ckpt_path: None
+
+     dataset:
+       target: semanticist.utils.datasets.ImageNet
+       params:
+         root: ./dataset/imagenet/
+         split: train
+         img_size: 256
+
+     test_dataset:
+       target: semanticist.utils.datasets.ImageNet
+       params:
+         root: ./dataset/imagenet/
+         split: val
+         img_size: 256
configs/tokenizer_xl.yaml ADDED
@@ -0,0 +1,55 @@
+ trainer:
+   target: semanticist.engine.diffusion_trainer.DiffusionTrainer
+   params:
+     num_epoch: 400
+     valid_size: 64
+     blr: 2.5e-5
+     cosine_lr: True
+     warmup_epochs: 100
+     batch_size: 256
+     num_workers: 16
+     pin_memory: True
+     grad_accum_steps: 1
+     precision: 'bf16'
+     max_grad_norm: 3.0
+     enable_ema: True
+     save_every: 10000
+     sample_every: 5000
+     fid_every: 50000
+     result_folder: "./output/tokenizer/models_xl"
+     log_dir: "./output/tokenizer/models_xl/logs"
+     cfg: 3.0
+     compile: True
+     model:
+       target: semanticist.stage1.diffuse_slot.DiffuseSlot
+       params:
+         encoder: 'vit_base_patch16'
+         enc_img_size: 256
+         enc_causal: True
+         enc_use_mlp: False
+         num_slots: 256
+         slot_dim: 16
+         norm_slots: True
+         dit_model: 'DiT-XL-2'
+         vae: 'xwen99/mar-vae-kl16'
+         enable_nest: False
+         enable_nest_after: 50
+         use_repa: True
+         eval_fid: True
+         fid_stats: 'fid_stats/adm_in256_stats.npz'
+         num_sampling_steps: '250'
+         ckpt_path: None
+
+     dataset:
+       target: semanticist.utils.datasets.ImageNet
+       params:
+         root: ./dataset/imagenet/
+         split: train
+         img_size: 256
+
+     test_dataset:
+       target: semanticist.utils.datasets.ImageNet
+       params:
+         root: ./dataset/imagenet/
+         split: val
+         img_size: 256
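
Both tokenizer configs are consumed the same way: `trainer.target` names the class, and everything under `trainer.params` (including the nested `model`, `dataset`, and `test_dataset` configs, which the trainer instantiates itself) is passed to its constructor. A minimal sketch of driving a run this way, assuming a plain single-process launch (the dedicated training script is not part of this commit):

from omegaconf import OmegaConf
from semanticist.engine.trainer_utils import instantiate_from_config

cfg = OmegaConf.load("configs/tokenizer_xl.yaml")
trainer = instantiate_from_config(cfg.trainer)  # builds DiffusionTrainer, model, and datasets
trainer.train(config=cfg)  # train() also saves a copy of the config to result_folder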
examples/city.jpg ADDED
examples/food.jpg ADDED
examples/highland.webp ADDED
gen_demo.py ADDED
@@ -0,0 +1,262 @@
+ import gradio as gr
+ import numpy as np
+ from PIL import Image
+ import os.path as osp
+ import torch
+ import matplotlib.pyplot as plt
+ from omegaconf import OmegaConf
+ from tqdm import tqdm
+ from huggingface_hub import hf_hub_download
+ from semanticist.engine.trainer_utils import instantiate_from_config
+ from semanticist.stage1.diffuse_slot import DiffuseSlot
+ from semanticist.stage2.gpt import GPT_models
+ from semanticist.stage2.generate import generate
+ from safetensors import safe_open
+ from semanticist.utils.datasets import vae_transforms
+ from imagenet_classes import imagenet_classes
+
+ transform = vae_transforms('test')
+
+
+ def norm_ip(img, low, high):
+     img.clamp_(min=low, max=high)
+     img.sub_(low).div_(max(high - low, 1e-5))
+
+ def norm_range(t, value_range):
+     if value_range is not None:
+         norm_ip(t, value_range[0], value_range[1])
+     else:
+         norm_ip(t, float(t.min()), float(t.max()))
+
+ def convert_np(img):
+     ndarr = img.mul(255).add_(0.5).clamp_(0, 255)\
+         .permute(1, 2, 0).to("cpu", torch.uint8).numpy()
+     return ndarr
+
+ def convert_PIL(img):
+     ndarr = img.mul(255).add_(0.5).clamp_(0, 255)\
+         .permute(1, 2, 0).to("cpu", torch.uint8).numpy()
+     img = Image.fromarray(ndarr)
+     return img
+
+ def norm_slots(slots):
+     mean = torch.mean(slots, dim=-1, keepdim=True)
+     std = torch.std(slots, dim=-1, keepdim=True)
+     return (slots - mean) / std
+
+ def load_state_dict(state_dict, model):
+     """Helper to load a state dict with proper prefix handling."""
+     if 'state_dict' in state_dict:
+         state_dict = state_dict['state_dict']
+     # Remove '_orig_mod' prefix if present
+     state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
+     missing, unexpected = model.load_state_dict(
+         state_dict, strict=False
+     )
+     # print(f"Loaded model. Missing: {missing}, Unexpected: {unexpected}")
+
+ def load_safetensors(path, model):
+     """Helper to load a safetensors checkpoint."""
+     from safetensors.torch import safe_open
+     with safe_open(path, framework="pt", device="cpu") as f:
+         state_dict = {k: f.get_tensor(k) for k in f.keys()}
+     load_state_dict(state_dict, model)
+
+ def load_checkpoint(ckpt_path, model):
+     if ckpt_path is None or not osp.exists(ckpt_path):
+         return
+
+     if osp.isdir(ckpt_path):
+         # ckpt_path is something like 'path/to/models/step10/'
+         model_path = osp.join(ckpt_path, "model.safetensors")
+         if osp.exists(model_path):
+             load_safetensors(model_path, model)
+     else:
+         # ckpt_path is something like 'path/to/models/step10.pt'
+         if ckpt_path.endswith(".safetensors"):
+             load_safetensors(ckpt_path, model)
+         else:
+             state_dict = torch.load(ckpt_path, map_location="cpu")
+             load_state_dict(state_dict, model)
+
+     print(f"Loaded checkpoint from {ckpt_path}")
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Is CUDA available: {torch.cuda.is_available()}")
+ if device == 'cuda':
+     print(f"CUDA device: {torch.cuda.get_device_name(torch.cuda.current_device())}")
+
+ ckpt_path = hf_hub_download(repo_id='tennant/semanticist', filename="semanticist_ar_gen_L.pkl", cache_dir='/mnt/ceph_rbd/mnt_pvc_vid_data/zbc/cache/')
+ config_path = 'configs/autoregressive_xl.yaml'
+
+ cfg = OmegaConf.load(config_path)
+ params = cfg.trainer.params
+
+ ae_model = instantiate_from_config(params.ae_model).to(device)
+ ae_model_path = hf_hub_download(repo_id='tennant/semanticist', filename="semanticist_tok_XL.pkl", cache_dir='/mnt/ceph_rbd/mnt_pvc_vid_data/zbc/cache/')
+ load_checkpoint(ae_model_path, ae_model)
+ ae_model.eval()
+
+ gpt_model = GPT_models[params.gpt_model.target](**params.gpt_model.params).to(device)
+ load_checkpoint(ckpt_path, gpt_model)
+ gpt_model.eval()
+
+ def viz_diff_slots(model, slots, nums, cfg=1.0, return_figs=False):
+     n_slots_inf = []
+     for num_slots_to_inference in nums:
+         drop_mask = model.nested_sampler(slots.shape[0], device, num_slots_to_inference)
+         recon_n = model.sample(slots, drop_mask=drop_mask, cfg=cfg)
+         n_slots_inf.append(recon_n)
+     return [convert_np(n_slots_inf[i][0]) for i in range(len(n_slots_inf))]
+
+ num_slots = params.ae_model.params.num_slots
+ slot_dim = params.ae_model.params.slot_dim
+ dtype = torch.bfloat16
+ # the model is trained with only 32 tokens.
+ num_slots_to_gen = 32
+
+ # Function to generate image from class
+ def generate_from_class(class_id, cfg_scale):
+     with torch.no_grad():
+         dtype = torch.bfloat16
+         num_slots_to_gen = 32
+         with torch.autocast(device, dtype=dtype):
+             slots_gen = generate(
+                 gpt_model,
+                 torch.tensor([class_id]).to(device),
+                 num_slots_to_gen,
+                 cfg_scale=cfg_scale,
+                 cfg_schedule="linear"
+             )
+         if num_slots_to_gen < num_slots:
+             null_slots = ae_model.dit.null_cond.expand(slots_gen.shape[0], -1, -1)
+             null_slots = null_slots[:, num_slots_to_gen:, :]
+             slots_gen = torch.cat([slots_gen, null_slots], dim=1)
+     return slots_gen
+
+ with gr.Blocks() as demo:
+     with gr.Row():
+         # First column - Input and configs
+         with gr.Column(scale=1):
+             gr.Markdown("## Input")
+
+             # Replace image input with ImageNet class selection
+             imagenet_classes = {k: v for k, v in enumerate(imagenet_classes)}
+             class_choices = [f"{id}: {name}" for id, name in imagenet_classes.items()]
+
+             # Dropdown for class selection
+             class_dropdown = gr.Dropdown(
+                 choices=class_choices[:20],  # Limit for demonstration
+                 label="Select ImageNet Class",
+                 value=class_choices[0] if class_choices else None
+             )
+
+             # Option to enter class ID directly
+             class_id_input = gr.Number(
+                 label="Or enter class ID directly (0-999)",
+                 value=0,
+                 minimum=0,
+                 maximum=999,
+                 step=1
+             )
+
+             with gr.Group():
+                 gr.Markdown("### Configuration")
+                 show_gallery = gr.Checkbox(label="Show Gallery", value=True)
+                 slider = gr.Slider(minimum=0.1, maximum=20.0, value=4.0, label="CFG value")
+                 labels_input = gr.Textbox(
+                     label="Number of tokens to reconstruct (comma-separated)",
+                     value="1, 2, 4, 8, 16",
+                     placeholder="Enter comma-separated numbers for the number of slots to use"
+                 )
+
+         # Second column - Output (conditionally rendered)
+         with gr.Column(scale=1):
+             gr.Markdown("## Output")
+
+             # Container for conditional rendering
+             with gr.Group(visible=True) as gallery_container:
+                 gallery = gr.Gallery(label="Result Gallery", columns=3, height="auto", show_label=True)
+
+             # Always visible output image
+             output_image = gr.Image(label="Generated Image", type="numpy")
+
+     # Handle form submission
+     submit_btn = gr.Button("Generate")
+
+     # Define the processing logic
+     def update_outputs(class_selection, class_id, show_gallery_value, slider_value, labels_text):
+         # Determine which class to use - either from dropdown or direct input
+         if class_selection:
+             # Extract class ID from the dropdown selection
+             selected_class_id = int(class_selection.split(":")[0])
+         else:
+             selected_class_id = int(class_id)
+
+         # Update the visibility of the gallery container
+         gallery_container.visible = show_gallery_value
+
+         try:
+             # Parse the labels from the text input
+             if labels_text and "," in labels_text:
+                 labels = [int(label.strip()) for label in labels_text.split(",")]
+             else:
+                 # Default labels if none provided or in wrong format
+                 labels = [1, 4, 16, 64, 256]
+         except ValueError:
+             labels = [1, 4, 16, 64, 256]
+
+         while len(labels) < 3:
+             labels.append(256)
+
+         # Generate the image based on the selected class
+         slots_gen = generate_from_class(selected_class_id, cfg_scale=slider_value)
+
+         recon = viz_diff_slots(ae_model, slots_gen, [32], cfg=slider_value)[0]
+
+         # Always generate the model decomposition for potential gallery display
+         model_decompose = viz_diff_slots(ae_model, slots_gen, labels, cfg=slider_value)
+
+         if not show_gallery_value:
+             # If only the image should be shown, return just the processed image
+             return gallery_container, [], recon
+         else:
+             # Create image variations and pair them with labels
+             gallery_images = [
+                 (recon, f'Generated from class {selected_class_id}'),
+             ] + [(img, 'Gen. with ' + str(label) + ' tokens') for img, label in zip(model_decompose, labels)]
+             return gallery_container, gallery_images, recon
+
+     # Connect the inputs and outputs
+     submit_btn.click(
+         fn=update_outputs,
+         inputs=[class_dropdown, class_id_input, show_gallery, slider, labels_input],
+         outputs=[gallery_container, gallery, output_image]
+     )
+
+     # Also update when checkbox changes
+     show_gallery.change(
+         fn=lambda value: gr.update(visible=value),
+         inputs=[show_gallery],
+         outputs=[gallery_container]
+     )
+
+     # Add examples
+     examples = [
+         # ["0: tench, Tinca tinca", 0, True, 4.0, "1,2,4,8,16"],
+         ["1: goldfish, Carassius auratus", 1, True, 4.0, "1,2,4,8,16"],
+         # ["2: great white shark, white shark", 2, True, 4.0, "1,2,4,8,16"],
+     ]
+
+     gr.Examples(
+         examples=examples,
+         inputs=[class_dropdown, class_id_input, show_gallery, slider, labels_input],
+         outputs=[gallery_container, gallery, output_image],
+         fn=update_outputs,
+         cache_examples=False
+     )
+
+ # Launch the demo
+ if __name__ == "__main__":
+     demo.launch()
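
Outside the Gradio UI, the module-level models and helpers defined above can be driven directly. A minimal sketch, assuming `gen_demo` has been imported far enough to define `ae_model`, `generate_from_class`, and `viz_diff_slots` (class 1 is goldfish):

from PIL import Image

slots = generate_from_class(class_id=1, cfg_scale=4.0)  # [1, 256, 16] after null-slot padding
imgs = viz_diff_slots(ae_model, slots, nums=[1, 4, 16, 32], cfg=4.0)  # list of uint8 HWC arrays
for n, arr in zip([1, 4, 16, 32], imgs):
    Image.fromarray(arr).save(f"goldfish_{n}_tokens.png")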
requirements.txt ADDED
@@ -0,0 +1,16 @@
+ numpy==1.26.4
+ accelerate
+ diffusers[torch]
+ transformers
+ safetensors
+ omegaconf
+ tensorboard
+ huggingface-hub
+ einops
+ timm
+ scipy
+ scikit-learn
+ scikit-image
+ git+https://github.com/LTH14/torch-fidelity.git@master#egg=torch-fidelity
+ opencv-python-headless
+ torchmetrics
semanticist/engine/diffusion_trainer.py ADDED
@@ -0,0 +1,488 @@
+ import os, torch
+ import os.path as osp
+ import shutil
+ from tqdm.auto import tqdm
+ from einops import rearrange
+ from accelerate import Accelerator
+ from torchvision.utils import make_grid, save_image
+ from torch.utils.data import DataLoader, random_split, DistributedSampler
+ from semanticist.utils.logger import SmoothedValue, MetricLogger, empty_cache
+ from accelerate.utils import DistributedDataParallelKwargs
+ from torchmetrics.functional.image import (
+     peak_signal_noise_ratio as psnr,
+     structural_similarity_index_measure as ssim
+ )
+ from semanticist.engine.trainer_utils import (
+     instantiate_from_config, concat_all_gather,
+     save_img_batch, get_fid_stats,
+     EMAModel, PaddedDataset, create_scheduler, load_state_dict,
+     load_safetensors, setup_result_folders, create_optimizer
+ )
+
+ class DiffusionTrainer:
+     def __init__(
+         self,
+         model,
+         dataset,
+         test_dataset=None,
+         test_only=False,
+         num_epoch=400,
+         valid_size=32,
+         blr=1e-4,
+         cosine_lr=True,
+         lr_min=0,
+         warmup_epochs=100,
+         warmup_steps=None,
+         warmup_lr_init=0,
+         decay_steps=None,
+         batch_size=32,
+         eval_bs=32,
+         test_bs=64,
+         num_workers=8,
+         pin_memory=False,
+         max_grad_norm=None,
+         grad_accum_steps=1,
+         precision='bf16',
+         save_every=10000,
+         sample_every=1000,
+         fid_every=50000,
+         result_folder=None,
+         log_dir="./log",
+         cfg=3.0,
+         test_num_slots=None,
+         eval_fid=False,
+         fid_stats=None,
+         enable_ema=False,
+         compile=False,
+     ):
+         kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
+         self.accelerator = Accelerator(
+             kwargs_handlers=[kwargs],
+             mixed_precision=precision,
+             gradient_accumulation_steps=grad_accum_steps,
+             log_with="tensorboard",
+             project_dir=log_dir,
+         )
+
+         self.model = instantiate_from_config(model)
+         self.num_slots = model.params.num_slots
+
+         if test_dataset is not None:
+             test_dataset = instantiate_from_config(test_dataset)
+             self.test_ds = test_dataset
+
+             # Calculate padded dataset size to ensure even distribution
+             total_size = len(test_dataset)
+             world_size = self.accelerator.num_processes
+             padding_size = world_size * test_bs - (total_size % (world_size * test_bs))
+             self.test_dataset_size = total_size
+
+             self.test_ds = PaddedDataset(self.test_ds, padding_size)
+             self.test_dl = DataLoader(
+                 self.test_ds,
+                 batch_size=test_bs,
+                 num_workers=num_workers,
+                 pin_memory=pin_memory,
+                 shuffle=False,
+                 drop_last=True,
+             )
+             if self.accelerator.is_main_process:
+                 print(f"test dataset size: {len(test_dataset)}, test batch size: {test_bs}")
+         else:
+             self.test_dl = None
+         self.test_only = test_only
+
+         if not test_only:
+             dataset = instantiate_from_config(dataset)
+             train_size = len(dataset) - valid_size
+             self.train_ds, self.valid_ds = random_split(
+                 dataset,
+                 [train_size, valid_size],
+                 generator=torch.Generator().manual_seed(42),
+             )
+             if self.accelerator.is_main_process:
+                 print(f"train dataset size: {train_size}, valid dataset size: {valid_size}")
+
+             sampler = DistributedSampler(
+                 self.train_ds,
+                 num_replicas=self.accelerator.num_processes,
+                 rank=self.accelerator.process_index,
+                 shuffle=True,
+             )
+             self.train_dl = DataLoader(
+                 self.train_ds,
+                 batch_size=batch_size,
+                 sampler=sampler,
+                 num_workers=num_workers,
+                 pin_memory=pin_memory,
+                 drop_last=True,
+             )
+             self.valid_dl = DataLoader(
+                 self.valid_ds,
+                 batch_size=eval_bs,
+                 shuffle=False,
+                 num_workers=num_workers,
+                 pin_memory=pin_memory,
+                 drop_last=False,
+             )
+
+             effective_bs = batch_size * grad_accum_steps * self.accelerator.num_processes
+             lr = blr * effective_bs / 256
+             if self.accelerator.is_main_process:
+                 print(f"Effective batch size is {effective_bs}")
+
+             self.g_optim = create_optimizer(self.model, weight_decay=0.05, learning_rate=lr)  # accelerator=self.accelerator
+
+             if warmup_epochs is not None:
+                 warmup_steps = warmup_epochs * len(self.train_dl)
+
+             self.g_sched = create_scheduler(
+                 self.g_optim,
+                 num_epoch,
+                 len(self.train_dl),
+                 lr_min,
+                 warmup_steps,
+                 warmup_lr_init,
+                 decay_steps,
+                 cosine_lr
+             )
+             self.accelerator.register_for_checkpointing(self.g_sched)
+             self.model, self.g_optim, self.g_sched = self.accelerator.prepare(self.model, self.g_optim, self.g_sched)
+         else:
+             self.model, self.test_dl = self.accelerator.prepare(self.model, self.test_dl)
+
+         self.steps = 0
+         self.loaded_steps = -1
+
+         if compile:
+             _model = self.accelerator.unwrap_model(self.model)
+             _model.vae = torch.compile(_model.vae, mode="reduce-overhead")
+             _model.dit = torch.compile(_model.dit, mode="reduce-overhead")
+             # _model.encoder = torch.compile(_model.encoder, mode="reduce-overhead")  # nan loss when compiled together with dit, no idea why
+             _model.encoder2slot = torch.compile(_model.encoder2slot, mode="reduce-overhead")
+
+         self.enable_ema = enable_ema
+         if self.enable_ema and not self.test_only:  # when testing, we directly load the ema dict and skip here
+             self.ema_model = EMAModel(self.accelerator.unwrap_model(self.model), self.device)
+             self.accelerator.register_for_checkpointing(self.ema_model)
+
+         self._load_checkpoint(model.params.ckpt_path)
+         if self.test_only:
+             self.steps = self.loaded_steps
+
+         self.num_epoch = num_epoch
+         self.save_every = save_every
+         self.sample_every = sample_every
+         self.fid_every = fid_every
+         self.max_grad_norm = max_grad_norm
+
+         self.cfg = cfg
+         self.test_num_slots = test_num_slots
+         if self.test_num_slots is not None:
+             self.test_num_slots = min(self.test_num_slots, self.num_slots)
+         else:
+             self.test_num_slots = self.num_slots
+         eval_fid = eval_fid or model.params.eval_fid  # legacy
+         self.eval_fid = eval_fid
+         if eval_fid:
+             if fid_stats is None:
+                 fid_stats = model.params.fid_stats  # legacy
+             assert fid_stats is not None
+             assert test_dataset is not None
+             self.fid_stats = fid_stats
+
+         self.result_folder = result_folder
+         self.model_saved_dir, self.image_saved_dir = setup_result_folders(result_folder)
+
+     @property
+     def device(self):
+         return self.accelerator.device
+
+     def _load_checkpoint(self, ckpt_path=None):
+         if ckpt_path is None or not osp.exists(ckpt_path):
+             return
+
+         model = self.accelerator.unwrap_model(self.model)
+
+         if osp.isdir(ckpt_path):
+             # ckpt_path is something like 'path/to/models/step10/'
+             self.loaded_steps = int(
+                 ckpt_path.split("step")[-1].split("/")[0]
+             )
+             if not self.test_only:
+                 self.accelerator.load_state(ckpt_path)
+             else:
+                 if self.enable_ema:
+                     model_path = osp.join(ckpt_path, "custom_checkpoint_1.pkl")
+                     if osp.exists(model_path):
+                         state_dict = torch.load(model_path, map_location="cpu")
+                         load_state_dict(state_dict, model)
+                         if self.accelerator.is_main_process:
+                             print(f"Loaded ema model from {model_path}")
+                 else:
+                     model_path = osp.join(ckpt_path, "model.safetensors")
+                     if osp.exists(model_path):
+                         load_safetensors(model_path, model)
+         else:
+             # ckpt_path is something like 'path/to/models/step10.pt'
+             if ckpt_path.endswith(".safetensors"):
+                 load_safetensors(ckpt_path, model)
+             else:
+                 state_dict = torch.load(ckpt_path, map_location="cpu")
+                 load_state_dict(state_dict, model)
+         if self.accelerator.is_main_process:
+             print(f"Loaded checkpoint from {ckpt_path}")
+
+     def train(self, config=None):
+         n_parameters = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
+         if self.accelerator.is_main_process:
+             print(f"number of learnable parameters: {n_parameters//1e6}M")
+             if config is not None:
+                 # save the config
+                 from omegaconf import OmegaConf
+                 if isinstance(config, str) and osp.exists(config):
+                     # If it's a path, copy the file to config.yaml
+                     shutil.copy(config, osp.join(self.result_folder, "config.yaml"))
+                 else:
+                     # If it's an OmegaConf object, dump it
+                     config_save_path = osp.join(self.result_folder, "config.yaml")
+                     OmegaConf.save(config, config_save_path)
+
+         self.accelerator.init_trackers("semanticist")
+
+         if self.test_only:
+             empty_cache()
+             self.evaluate()
+             self.accelerator.wait_for_everyone()
+             empty_cache()
+             return
+
+         for epoch in range(self.num_epoch):
+             if ((epoch + 1) * len(self.train_dl)) <= self.loaded_steps:
+                 if self.accelerator.is_main_process:
+                     print(f"Epoch {epoch} is skipped because it is loaded from ckpt")
+                 self.steps += len(self.train_dl)
+                 continue
+
+             if self.steps < self.loaded_steps:
+                 for _ in self.train_dl:
+                     self.steps += 1
+                     if self.steps >= self.loaded_steps:
+                         break
+
+             self.accelerator.unwrap_model(self.model).current_epoch = epoch
+             self.model.train()  # Set model to training mode
+
+             logger = MetricLogger(delimiter="  ")
+             logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
+             header = 'Epoch: [{}/{}]'.format(epoch, self.num_epoch)
+             print_freq = 20
+             for data_iter_step, batch in enumerate(logger.log_every(self.train_dl, print_freq, header)):
+                 img, _ = batch
+                 img = img.to(self.device, non_blocking=True)
+                 self.steps += 1
+
+                 with self.accelerator.accumulate(self.model):
+                     with self.accelerator.autocast():
+                         if self.steps == 1:
+                             print(f"Training batch size: {img.size(0)}")
+                             print(f"Hello from index {self.accelerator.local_process_index}")
+                         losses = self.model(img, epoch=epoch)
+                         # combine
+                         loss = sum([v for _, v in losses.items()])
+
+                     self.accelerator.backward(loss)
+                     if self.accelerator.sync_gradients and self.max_grad_norm is not None:
+                         self.accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
+                     self.g_optim.step()
+                     if self.g_sched is not None:
+                         self.g_sched.step_update(self.steps)
+                     self.g_optim.zero_grad()
+
+                 self.accelerator.wait_for_everyone()
+
+                 # update ema with state dict
+                 if self.enable_ema:
+                     self.ema_model.update(self.accelerator.unwrap_model(self.model))
+
+                 for key, value in losses.items():
+                     logger.update(**{key: value.item()})
+                 logger.update(lr=self.g_optim.param_groups[0]["lr"])
+
+                 if self.steps % self.save_every == 0:
+                     self.save()
+
+                 if (self.steps % self.sample_every == 0) or (self.steps % self.fid_every == 0):
+                     empty_cache()
+                     self.evaluate()
+                     self.accelerator.wait_for_everyone()
+                     empty_cache()
+
+                 write_dict = dict(epoch=epoch)
+                 for key, value in losses.items():  # omitted all_gather here
+                     write_dict.update(**{key: value.item()})
+                 write_dict.update(lr=self.g_optim.param_groups[0]["lr"])
+                 self.accelerator.log(write_dict, step=self.steps)
+
+             logger.synchronize_between_processes()
+             if self.accelerator.is_main_process:
+                 print("Averaged stats:", logger)
+
+         self.accelerator.end_training()
+         self.save()
+         if self.accelerator.is_main_process:
+             print("Train finished!")
+
+     def save(self):
+         self.accelerator.wait_for_everyone()
+         self.accelerator.save_state(
+             os.path.join(self.model_saved_dir, f"step{self.steps}")
+         )
+
+     @torch.no_grad()
+     def evaluate(self):
+         self.model.eval()
+         if not self.test_only:
+             with tqdm(
+                 self.valid_dl,
+                 dynamic_ncols=True,
+                 disable=not self.accelerator.is_main_process,
+             ) as valid_dl:
+                 for batch_i, batch in enumerate(valid_dl):
+                     if isinstance(batch, tuple) or isinstance(batch, list):
+                         img, targets = batch[0], batch[1]
+                     else:
+                         img = batch
+
+                     with self.accelerator.autocast():
+                         rec = self.model(img, sample=True, inference_with_n_slots=self.test_num_slots, cfg=1.0)
+                     imgs_and_recs = torch.stack((img.to(rec.device), rec), dim=0)
+                     imgs_and_recs = rearrange(imgs_and_recs, "r b ... -> (b r) ...")
+                     imgs_and_recs = imgs_and_recs.detach().cpu().float()
+
+                     grid = make_grid(
+                         imgs_and_recs, nrow=6, normalize=True, value_range=(0, 1)
+                     )
+                     if self.accelerator.is_main_process:
+                         save_image(
+                             grid,
+                             os.path.join(
+                                 self.image_saved_dir, f"step_{self.steps}_slots{self.test_num_slots}_{batch_i}.jpg"
+                             ),
+                         )
+
+                     if self.cfg != 1.0:
+                         with self.accelerator.autocast():
+                             rec = self.model(img, sample=True, inference_with_n_slots=self.test_num_slots, cfg=self.cfg)
+
+                         imgs_and_recs = torch.stack((img.to(rec.device), rec), dim=0)
+                         imgs_and_recs = rearrange(imgs_and_recs, "r b ... -> (b r) ...")
+                         imgs_and_recs = imgs_and_recs.detach().cpu().float()
+
+                         grid = make_grid(
+                             imgs_and_recs, nrow=6, normalize=True, value_range=(0, 1)
+                         )
+                         if self.accelerator.is_main_process:
+                             save_image(
+                                 grid,
+                                 os.path.join(
+                                     self.image_saved_dir, f"step_{self.steps}_cfg_{self.cfg}_slots{self.test_num_slots}_{batch_i}.jpg"
+                                 ),
+                             )
+         if (self.eval_fid and self.test_dl is not None) and (self.test_only or (self.steps % self.fid_every == 0)):
+             real_dir = "./dataset/imagenet/val256"
+             rec_dir = os.path.join(self.image_saved_dir, f"rec_step{self.steps}_slots{self.test_num_slots}")
+             os.makedirs(rec_dir, exist_ok=True)
+
+             if self.cfg != 1.0:
+                 rec_cfg_dir = os.path.join(self.image_saved_dir, f"rec_step{self.steps}_cfg_{self.cfg}_slots{self.test_num_slots}")
+                 os.makedirs(rec_cfg_dir, exist_ok=True)
+
+             def process_batch(cfg_value, save_dir, header):
+                 logger = MetricLogger(delimiter="  ")
+                 print_freq = 5
+                 psnr_values = []
+                 ssim_values = []
+                 total_processed = 0
+
+                 for batch_i, batch in enumerate(logger.log_every(self.test_dl, print_freq, header)):
+                     imgs, targets = (batch[0], batch[1]) if isinstance(batch, (tuple, list)) else (batch, None)
+
+                     # Skip processing if we've already processed all real samples
+                     if total_processed >= self.test_dataset_size:
+                         break
+
+                     imgs = imgs.to(self.device, non_blocking=True)
+                     if targets is not None:
+                         targets = targets.to(self.device, non_blocking=True)
+
+                     with self.accelerator.autocast():
+                         recs = self.model(imgs, sample=True, inference_with_n_slots=self.test_num_slots, cfg=cfg_value)
+
+                     psnr_val = psnr(recs, imgs, data_range=1.0)
+                     ssim_val = ssim(recs, imgs, data_range=1.0)
+
+                     recs = concat_all_gather(recs).detach()
+                     psnr_val = concat_all_gather(psnr_val.view(1))
+                     ssim_val = concat_all_gather(ssim_val.view(1))
+
+                     # Remove padding after gathering from all GPUs
+                     samples_in_batch = min(
+                         recs.size(0),  # Always use the gathered size
+                         self.test_dataset_size - total_processed
+                     )
+                     recs = recs[:samples_in_batch]
+                     psnr_val = psnr_val[:samples_in_batch]
+                     ssim_val = ssim_val[:samples_in_batch]
+                     psnr_values.append(psnr_val)
+                     ssim_values.append(ssim_val)
+
+                     if self.accelerator.is_main_process:
+                         rec_paths = [os.path.join(save_dir, f"step_{self.steps}_slots{self.test_num_slots}_{batch_i}_{j}_rec_cfg_{cfg_value}_slots{self.test_num_slots}.png")
+                                      for j in range(recs.size(0))]
+                         save_img_batch(recs.cpu(), rec_paths)
+
+                     total_processed += samples_in_batch
+
+                     self.accelerator.wait_for_everyone()
+
+                 return torch.cat(psnr_values).mean(), torch.cat(ssim_values).mean()
+
+             # Helper function to calculate and log metrics
+             def calculate_and_log_metrics(real_dir, rec_dir, cfg_value, psnr_val, ssim_val):
+                 if self.accelerator.is_main_process:
+                     metrics_dict = get_fid_stats(real_dir, rec_dir, self.fid_stats)
+                     fid = metrics_dict["frechet_inception_distance"]
+                     inception_score = metrics_dict["inception_score_mean"]
+
+                     metric_prefix = "fid"
+                     isc_prefix = "isc"
+                     self.accelerator.log({
+                         metric_prefix: fid,
+                         isc_prefix: inception_score,
+                         "psnr": psnr_val,
+                         "ssim": ssim_val,
+                         "cfg": cfg_value
+                     }, step=self.steps)
+
+                     print(f"CFG: {cfg_value} "
+                           f"FID: {fid:.2f}, ISC: {inception_score:.2f}, "
+                           f"PSNR: {psnr_val:.2f}, SSIM: {ssim_val:.4f}")
+
+             # Process without CFG
+             if self.cfg == 1.0 or not self.test_only:
+                 psnr_val, ssim_val = process_batch(1.0, rec_dir, 'Testing: w/o CFG')
+                 calculate_and_log_metrics(real_dir, rec_dir, 1.0, psnr_val, ssim_val)
+
+             # Process with CFG if needed
+             if self.cfg != 1.0:
+                 psnr_val, ssim_val = process_batch(self.cfg, rec_cfg_dir, 'Testing: w/ CFG')
+                 calculate_and_log_metrics(real_dir, rec_cfg_dir, self.cfg, psnr_val, ssim_val)
+
+             # Cleanup
+             if self.accelerator.is_main_process:
+                 shutil.rmtree(rec_dir)
+                 if self.cfg != 1.0:
+                     shutil.rmtree(rec_cfg_dir)
+         self.model.train()
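
`_load_checkpoint` above (and its near-identical twin in `gpt_trainer.py` below) accepts several on-disk layouts; the step-directory form is what `save()` produces via `Accelerator.save_state`. A short summary as implemented, with an illustrative call (the `trainer` variable is assumed):

# Layouts accepted by _load_checkpoint:
#   .../models/step250000/                 -> Accelerate save_state directory
#       model.safetensors                  ->   plain weights (enable_ema=False)
#       custom_checkpoint_1.pkl            ->   EMA weights (enable_ema=True)
#   .../weights.safetensors                -> single safetensors file
#   .../weights.pkl / .pt / .pth           -> any torch.load-able state dict
# The resume step is parsed from the directory name ("step250000" -> 250000).
trainer._load_checkpoint("./output/tokenizer/models_xl/step250000")  # also restores optimizer/EMA state when training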
semanticist/engine/gpt_trainer.py ADDED
@@ -0,0 +1,694 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os, torch
2
+ import os.path as osp
3
+ import shutil
4
+ import numpy as np
5
+ import copy
6
+ import torch.nn as nn
7
+ from tqdm.auto import tqdm
8
+ from accelerate import Accelerator
9
+ from torchvision.utils import make_grid, save_image
10
+ from torch.utils.data import DataLoader, DistributedSampler
11
+ from semanticist.utils.logger import SmoothedValue, MetricLogger, empty_cache
12
+ from accelerate.utils import DistributedDataParallelKwargs
13
+ from semanticist.stage2.gpt import GPT_models
14
+ from semanticist.stage2.generate import generate
15
+ from pathlib import Path
16
+ import time
17
+
18
+ from semanticist.engine.trainer_utils import (
19
+ instantiate_from_config, concat_all_gather,
20
+ save_img_batch, get_fid_stats,
21
+ EMAModel, create_scheduler, load_state_dict, load_safetensors,
22
+ setup_result_folders, create_optimizer,
23
+ CacheDataLoader
24
+ )
25
+
26
+ class GPTTrainer(nn.Module):
27
+ def __init__(
28
+ self,
29
+ ae_model,
30
+ gpt_model,
31
+ dataset,
32
+ test_only=False,
33
+ num_test_images=50000,
34
+ num_epoch=400,
35
+ eval_classes=[1, 7, 282, 604, 724, 207, 250, 751, 404, 850], # goldfish, cock, tiger cat, hourglass, ship, golden retriever, husky, race car, airliner, teddy bear
36
+ blr=1e-4,
37
+ cosine_lr=False,
38
+ lr_min=0,
39
+ warmup_epochs=100,
40
+ warmup_steps=None,
41
+ warmup_lr_init=0,
42
+ decay_steps=None,
43
+ batch_size=32,
44
+ cache_bs=8,
45
+ test_bs=100,
46
+ num_workers=8,
47
+ pin_memory=False,
48
+ max_grad_norm=None,
49
+ grad_accum_steps=1,
50
+ precision='bf16',
51
+ save_every=10000,
52
+ sample_every=1000,
53
+ fid_every=50000,
54
+ result_folder=None,
55
+ log_dir="./log",
56
+ ae_cfg=1.0,
57
+ cfg=6.0,
58
+ cfg_schedule="linear",
59
+ temperature=1.0,
60
+ train_num_slots=None,
61
+ test_num_slots=None,
62
+ eval_fid=False,
63
+ fid_stats=None,
64
+ enable_ema=False,
65
+ compile=False,
66
+ enable_cache_latents=True,
67
+ cache_dir='/dev/shm/slot_cache'
68
+ ):
69
+ super().__init__()
70
+ kwargs = DistributedDataParallelKwargs(find_unused_parameters=False)
71
+ self.accelerator = Accelerator(
72
+ kwargs_handlers=[kwargs],
73
+ mixed_precision=precision,
74
+ gradient_accumulation_steps=grad_accum_steps,
75
+ log_with="tensorboard",
76
+ project_dir=log_dir,
77
+ )
78
+
79
+ self.ae_model = instantiate_from_config(ae_model)
80
+ ae_model_path = ae_model.params.ckpt_path
81
+ assert ae_model_path.endswith(".safetensors") or ae_model_path.endswith(".pt") or ae_model_path.endswith(".pth") or ae_model_path.endswith(".pkl")
82
+ assert osp.exists(ae_model_path), f"AE model checkpoint {ae_model_path} does not exist"
83
+ self._load_checkpoint(ae_model_path, self.ae_model)
84
+
85
+ self.ae_model.to(self.device)
86
+ for param in self.ae_model.parameters():
87
+ param.requires_grad = False
88
+ self.ae_model.eval()
89
+
90
+ self.model_name = gpt_model.target
91
+ if 'GPT' in gpt_model.target:
92
+ self.gpt_model = GPT_models[gpt_model.target](**gpt_model.params)
93
+ else:
94
+ raise ValueError(f"Unknown model type: {gpt_model.target}")
95
+ self.num_slots = ae_model.params.num_slots
96
+ self.slot_dim = ae_model.params.slot_dim
97
+
98
+ self.test_only = test_only
99
+ self.test_bs = test_bs
100
+ self.num_test_images = num_test_images
101
+ self.num_classes = gpt_model.params.num_classes
102
+ self.batch_size = batch_size
103
+ if not test_only:
104
+ self.train_ds = instantiate_from_config(dataset)
105
+ train_size = len(self.train_ds)
106
+ if self.accelerator.is_main_process:
107
+ print(f"train dataset size: {train_size}")
108
+
109
+ sampler = DistributedSampler(
110
+ self.train_ds,
111
+ num_replicas=self.accelerator.num_processes,
112
+ rank=self.accelerator.process_index,
113
+ shuffle=True,
114
+ )
115
+ self.train_dl = DataLoader(
116
+ self.train_ds,
117
+ batch_size=batch_size if not enable_cache_latents else cache_bs,
118
+ sampler=sampler,
119
+ num_workers=num_workers,
120
+ pin_memory=pin_memory,
121
+ drop_last=True,
122
+ )
123
+
124
+ effective_bs = batch_size * grad_accum_steps * self.accelerator.num_processes
125
+ lr = blr * effective_bs / 256
126
+ if self.accelerator.is_main_process:
127
+ print(f"Effective batch size is {effective_bs}")
128
+
129
+ self.g_optim = create_optimizer(self.gpt_model, weight_decay=0.05, learning_rate=lr)
130
+
131
+ if warmup_epochs is not None:
132
+ warmup_steps = warmup_epochs * len(self.train_dl)
133
+
134
+ self.g_sched = create_scheduler(
135
+ self.g_optim,
136
+ num_epoch,
137
+ len(self.train_dl),
138
+ lr_min,
139
+ warmup_steps,
140
+ warmup_lr_init,
141
+ decay_steps,
142
+ cosine_lr
143
+ )
144
+ self.accelerator.register_for_checkpointing(self.g_sched)
145
+ self.gpt_model, self.g_optim, self.g_sched = self.accelerator.prepare(self.gpt_model, self.g_optim, self.g_sched)
146
+ else:
147
+ self.gpt_model = self.accelerator.prepare(self.gpt_model)
148
+
149
+ self.steps = 0
150
+ self.loaded_steps = -1
151
+
152
+ if compile:
153
+ self.ae_model = torch.compile(self.ae_model, mode="reduce-overhead")
154
+ _model = self.accelerator.unwrap_model(self.gpt_model)
155
+ _model = torch.compile(_model, mode="reduce-overhead")
156
+
157
+ self.enable_ema = enable_ema
158
+ if self.enable_ema and not self.test_only: # when testing, we directly load the ema dict and skip here
159
+ self.ema_model = EMAModel(self.accelerator.unwrap_model(self.gpt_model), self.device)
160
+ self.accelerator.register_for_checkpointing(self.ema_model)
161
+
162
+ self._load_checkpoint(gpt_model.params.ckpt_path)
163
+ if self.test_only:
164
+ self.steps = self.loaded_steps
165
+
166
+ self.num_epoch = num_epoch
167
+ self.save_every = save_every
168
+ self.sample_every = sample_every
169
+ self.fid_every = fid_every
170
+ self.max_grad_norm = max_grad_norm
171
+ self.eval_classes = eval_classes
172
+ self.cfg = cfg
173
+ self.ae_cfg = ae_cfg
174
+ self.cfg_schedule = cfg_schedule
175
+ self.temperature = temperature
176
+ self.train_num_slots = train_num_slots
177
+ self.test_num_slots = test_num_slots
178
+ if self.train_num_slots is not None:
179
+ self.train_num_slots = min(self.train_num_slots, self.num_slots)
180
+ else:
181
+ self.train_num_slots = self.num_slots
182
+ if self.test_num_slots is not None:
183
+ self.num_slots_to_gen = min(self.test_num_slots, self.train_num_slots)
184
+ else:
185
+ self.num_slots_to_gen = self.train_num_slots
186
+ self.eval_fid = eval_fid
187
+ if eval_fid:
188
+ assert fid_stats is not None
189
+ self.fid_stats = fid_stats
190
+
191
+ # Setup result folders
192
+ self.result_folder = result_folder
193
+ self.model_saved_dir, self.image_saved_dir = setup_result_folders(result_folder)
194
+
195
+ # Setup cache
196
+ self.cache_dir = Path(cache_dir)
197
+ self.enable_cache_latents = enable_cache_latents
198
+ self.cache_loader = None
199
+
200
+ @property
201
+ def device(self):
202
+ return self.accelerator.device
203
+
204
+ def _load_checkpoint(self, ckpt_path=None, model=None):
205
+ if ckpt_path is None or not osp.exists(ckpt_path):
206
+ return
207
+
208
+ if model is None:
209
+ model = self.accelerator.unwrap_model(self.gpt_model)
210
+
211
+ if osp.isdir(ckpt_path):
212
+ self.loaded_steps = int(
213
+ ckpt_path.split("step")[-1].split("/")[0]
214
+ )
215
+ if not self.test_only:
216
+ self.accelerator.load_state(ckpt_path)
217
+ else:
218
+ if self.enable_ema:
219
+ model_path = osp.join(ckpt_path, "custom_checkpoint_1.pkl")
220
+ if osp.exists(model_path):
221
+ state_dict = torch.load(model_path, map_location="cpu")
222
+ load_state_dict(state_dict, model)
223
+ if self.accelerator.is_main_process:
224
+ print(f"Loaded ema model from {model_path}")
225
+ else:
226
+ model_path = osp.join(ckpt_path, "model.safetensors")
227
+ if osp.exists(model_path):
228
+ load_safetensors(model_path, model)
229
+ else:
230
+ if ckpt_path.endswith(".safetensors"):
231
+ load_safetensors(ckpt_path, model)
232
+ else:
233
+ state_dict = torch.load(ckpt_path, map_location="cpu")
234
+ load_state_dict(state_dict, model)
235
+ if self.accelerator.is_main_process:
236
+ print(f"Loaded checkpoint from {ckpt_path}")
237
+
238
+ def _build_cache(self):
239
+ """Build cache for slots and targets."""
240
+ rank = self.accelerator.process_index
241
+ world_size = self.accelerator.num_processes
242
+
243
+ # Clean up any existing cache files first
244
+ slots_file = self.cache_dir / f"slots_rank{rank}_of_{world_size}.mmap"
245
+ targets_file = self.cache_dir / f"targets_rank{rank}_of_{world_size}.mmap"
246
+
247
+ if slots_file.exists():
248
+ os.remove(slots_file)
249
+ if targets_file.exists():
250
+ os.remove(targets_file)
251
+
252
+ dataset_size = len(self.train_dl.dataset)
253
+ shard_size = dataset_size // world_size
254
+
255
+ # Detect number of augmentations from first batch
256
+ with torch.no_grad():
257
+ sample_batch = next(iter(self.train_dl))
258
+ img, _ = sample_batch
259
+ num_augs = img.shape[1] if len(img.shape) == 5 else 1
260
+
261
+ print(f"Rank {rank}: Creating new cache with {num_augs} augmentations per image...")
262
+ os.makedirs(self.cache_dir, exist_ok=True)
263
+ slots_file = self.cache_dir / f"slots_rank{rank}_of_{world_size}.mmap"
264
+ targets_file = self.cache_dir / f"targets_rank{rank}_of_{world_size}.mmap"
265
+
266
+ # Create memory-mapped files
267
+ slots_mmap = np.memmap(
268
+ slots_file,
269
+ dtype='float32',
270
+ mode='w+',
271
+ shape=(shard_size * num_augs, self.train_num_slots, self.slot_dim)
272
+ )
273
+
274
+ targets_mmap = np.memmap(
275
+ targets_file,
276
+ dtype='int64',
277
+ mode='w+',
278
+ shape=(shard_size * num_augs,)
279
+ )
280
+
281
+ # Cache data
282
+ with torch.no_grad():
283
+ for i, batch in enumerate(tqdm(
284
+ self.train_dl,
285
+ desc=f"Rank {rank}: Caching data",
286
+ disable=not self.accelerator.is_local_main_process
287
+ )):
288
+ imgs, targets = batch
289
+ if len(imgs.shape) == 5: # [B, num_augs, C, H, W]
290
+ B, A, C, H, W = imgs.shape
291
+ imgs = imgs.view(-1, C, H, W) # [B*num_augs, C, H, W]
292
+ targets = targets.unsqueeze(1).expand(-1, A).reshape(-1) # [B*num_augs]
293
+
294
+ # Split imgs into n chunks
295
+ num_splits = num_augs
296
+ split_size = imgs.shape[0] // num_splits
297
+ imgs_splits = torch.split(imgs, split_size)
298
+ targets_splits = torch.split(targets, split_size)
299
+
300
+ start_idx = i * self.train_dl.batch_size * num_augs
301
+
302
+ for split_idx, (img_split, targets_split) in enumerate(zip(imgs_splits, targets_splits)):
303
+ img_split = img_split.to(self.device, non_blocking=True)
304
+ slots_split = self.ae_model.encode_slots(img_split)[:, :self.train_num_slots, :]
305
+
306
+ split_start = start_idx + (split_idx * split_size)
307
+ split_end = split_start + img_split.shape[0]
308
+
309
+ # Write directly to mmap files
310
+ slots_mmap[split_start:split_end] = slots_split.cpu().numpy()
311
+ targets_mmap[split_start:split_end] = targets_split.numpy()
312
+
313
+ # Close the mmap files
314
+ del slots_mmap
315
+ del targets_mmap
316
+
317
+ # Reopen in read mode
318
+ self.cached_latents = np.memmap(
319
+ slots_file,
320
+ dtype='float32',
321
+ mode='r',
322
+ shape=(shard_size * num_augs, self.train_num_slots, self.slot_dim)
323
+ )
324
+
325
+ self.cached_targets = np.memmap(
326
+ targets_file,
327
+ dtype='int64',
328
+ mode='r',
329
+ shape=(shard_size * num_augs,)
330
+ )
331
+
332
+ # Store the number of augmentations for the cache loader
333
+ self.num_augs = num_augs
334
+
335
+ def _setup_cache(self):
336
+ """Setup cache if enabled."""
337
+ self._build_cache()
338
+ self.accelerator.wait_for_everyone()
339
+
340
+ # Initialize cache loader if cache exists
341
+ if self.cached_latents is not None:
342
+ self.cache_loader = CacheDataLoader(
343
+ slots=self.cached_latents,
344
+ targets=self.cached_targets,
345
+ batch_size=self.batch_size,
346
+ num_augs=self.num_augs,
347
+ seed=42 + self.accelerator.process_index
348
+ )
349
+
350
+ def __del__(self):
351
+ """Cleanup cache files."""
352
+ if self.enable_cache_latents:
353
+ rank = self.accelerator.process_index
354
+ world_size = self.accelerator.num_processes
355
+
356
+ # Clean up slots cache
357
+ slots_file = self.cache_dir / f"slots_rank{rank}_of_{world_size}.mmap"
358
+ if slots_file.exists():
359
+ os.remove(slots_file)
360
+
361
+ # Clean up targets cache
362
+ targets_file = self.cache_dir / f"targets_rank{rank}_of_{world_size}.mmap"
363
+ if targets_file.exists():
364
+ os.remove(targets_file)
365
+
366
+ def _train_step(self, slots, targets=None):
367
+ """Execute single training step."""
368
+
369
+ with self.accelerator.accumulate(self.gpt_model):
370
+ with self.accelerator.autocast():
371
+ loss = self.gpt_model(slots, targets)
372
+
373
+ self.accelerator.backward(loss)
374
+ if self.accelerator.sync_gradients and self.max_grad_norm is not None:
375
+ self.accelerator.clip_grad_norm_(self.gpt_model.parameters(), self.max_grad_norm)
376
+ self.g_optim.step()
377
+ if self.g_sched is not None:
378
+ self.g_sched.step_update(self.steps)
379
+ self.g_optim.zero_grad()
380
+
381
+ # Update EMA model if enabled
382
+ if self.enable_ema:
383
+ self.ema_model.update(self.accelerator.unwrap_model(self.gpt_model))
384
+
385
+ return loss
386
+
387
+ def _train_epoch_cached(self, epoch, logger):
388
+ """Train one epoch using cached data."""
389
+ self.cache_loader.set_epoch(epoch)
390
+ header = f'Epoch: [{epoch}/{self.num_epoch}]'
391
+
392
+ for batch in logger.log_every(self.cache_loader, 20, header):
393
+ slots, targets = (b.to(self.device, non_blocking=True) for b in batch)
394
+
395
+ self.steps += 1
396
+
397
+ if self.steps == 1:
398
+ print(f"Training batch size: {len(slots)}")
399
+ print(f"Hello from index {self.accelerator.local_process_index}")
400
+
401
+ loss = self._train_step(slots, targets)
402
+ self._handle_periodic_ops(loss, logger)
403
+
404
+ def _train_epoch_uncached(self, epoch, logger):
405
+ """Train one epoch using raw data."""
406
+ header = f'Epoch: [{epoch}/{self.num_epoch}]'
407
+
408
+ for batch in logger.log_every(self.train_dl, 20, header):
409
+ img, targets = (b.to(self.device, non_blocking=True) for b in batch)
410
+
411
+ self.steps += 1
412
+
413
+ if self.steps == 1:
414
+ print(f"Training batch size: {img.size(0)}")
415
+ print(f"Hello from index {self.accelerator.local_process_index}")
416
+
417
+ slots = self.ae_model.encode_slots(img)[:, :self.train_num_slots, :]
418
+ loss = self._train_step(slots, targets)
419
+ self._handle_periodic_ops(loss, logger)
420
+
421
+ def _handle_periodic_ops(self, loss, logger):
422
+ """Handle periodic operations and logging."""
423
+ logger.update(loss=loss.item())
424
+ logger.update(lr=self.g_optim.param_groups[0]["lr"])
425
+
426
+ if self.steps % self.save_every == 0:
427
+ self.save()
428
+
429
+ if (self.steps % self.sample_every == 0) or (self.eval_fid and self.steps % self.fid_every == 0):
430
+ empty_cache()
431
+ self.evaluate()
432
+ self.accelerator.wait_for_everyone()
433
+ empty_cache()
434
+
435
+ def _save_config(self, config):
436
+ """Save configuration file."""
437
+ if config is not None and self.accelerator.is_main_process:
438
+ import shutil
439
+ from omegaconf import OmegaConf
440
+
441
+ if isinstance(config, str) and osp.exists(config):
442
+ shutil.copy(config, osp.join(self.result_folder, "config.yaml"))
443
+ else:
444
+ config_save_path = osp.join(self.result_folder, "config.yaml")
445
+ OmegaConf.save(config, config_save_path)
446
+
447
+ def _should_skip_epoch(self, epoch):
448
+ """Check if epoch should be skipped due to loaded checkpoint."""
449
+ loader = self.train_dl if not self.enable_cache_latents else self.cache_loader
450
+ if ((epoch + 1) * len(loader)) <= self.loaded_steps:
451
+ if self.accelerator.is_main_process:
452
+ print(f"Epoch {epoch} is skipped because it is loaded from ckpt")
453
+ self.steps += len(loader)
454
+ return True
455
+
456
+ if self.steps < self.loaded_steps:
457
+ for _ in loader:
458
+ self.steps += 1
459
+ if self.steps >= self.loaded_steps:
460
+ break
461
+ return False
462
+
463
+ def train(self, config=None):
464
+ """Main training loop."""
465
+ # Initial setup
466
+ n_parameters = sum(p.numel() for p in self.parameters() if p.requires_grad)
467
+ if self.accelerator.is_main_process:
468
+ print(f"number of learnable parameters: {n_parameters//1e6}M")
469
+
470
+ self._save_config(config)
471
+ self.accelerator.init_trackers("gpt")
472
+
473
+ # Handle test-only mode
474
+ if self.test_only:
475
+ empty_cache()
476
+ self.evaluate()
477
+ self.accelerator.wait_for_everyone()
478
+ empty_cache()
479
+ return
480
+
481
+ # Setup cache if enabled
482
+ if self.enable_cache_latents:
483
+ self._setup_cache()
484
+
485
+ # Training loop
486
+ for epoch in range(self.num_epoch):
487
+ if self._should_skip_epoch(epoch):
488
+ continue
489
+
490
+ self.gpt_model.train()
491
+ logger = MetricLogger(delimiter=" ")
492
+ logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
493
+
494
+ # Choose training path based on cache availability
495
+ if self.enable_cache_latents:
496
+ self._train_epoch_cached(epoch, logger)
497
+ else:
498
+ self._train_epoch_uncached(epoch, logger)
499
+
500
+ # Synchronize and log epoch stats
501
+ logger.synchronize_between_processes()
502
+ if self.accelerator.is_main_process:
503
+ print("Averaged stats:", logger)
504
+
505
+ # Finish training
506
+ self.accelerator.end_training()
507
+ self.save()
508
+ if self.accelerator.is_main_process:
509
+ print("Train finished!")
510
+
511
+ def save(self):
512
+ self.accelerator.wait_for_everyone()
513
+ self.accelerator.save_state(
514
+ os.path.join(self.model_saved_dir, f"step{self.steps}")
515
+ )
516
+
517
+ @torch.no_grad()
518
+ def evaluate(self, use_ema=True):
519
+ self.gpt_model.eval()
520
+ unwraped_gpt_model = self.accelerator.unwrap_model(self.gpt_model)
521
+ # switch to ema params, only when eval_fid is True
522
+ # if test_only, we directly load the ema dict and skip here
523
+ use_ema = use_ema and self.enable_ema and self.eval_fid and not self.test_only
524
+ if use_ema:
525
+ if hasattr(self, "ema_model"):
526
+ model_without_ddp = self.accelerator.unwrap_model(self.gpt_model)
527
+ model_state_dict = copy.deepcopy(model_without_ddp.state_dict())
528
+ ema_state_dict = copy.deepcopy(model_without_ddp.state_dict())
529
+ for name, _ in model_without_ddp.named_parameters():
+ if "nested_sampler" in name:
+ continue
+ ema_state_dict[name] = self.ema_model.state_dict()[name]
533
+ if self.accelerator.is_main_process:
534
+ print("Switch to ema")
535
+ model_without_ddp.load_state_dict(ema_state_dict)
536
+ else:
537
+ print("EMA model not found, using original model")
538
+ use_ema = False
539
+
540
+ if not self.test_only:
541
+ classes = torch.tensor(self.eval_classes, device=self.device)
542
+ with self.accelerator.autocast():
543
+ slots = generate(unwraped_gpt_model, classes, self.num_slots_to_gen, cfg_scale=self.cfg, cfg_schedule=self.cfg_schedule, temperature=self.temperature)
544
+ if self.num_slots_to_gen < self.num_slots:
545
+ null_slots = self.ae_model.dit.null_cond.expand(slots.shape[0], -1, -1)
546
+ null_slots = null_slots[:, self.num_slots_to_gen:, :]
547
+ slots = torch.cat([slots, null_slots], dim=1)
548
+ imgs = self.ae_model.sample(slots, cfg=self.ae_cfg) # DiffuseSlot.sample takes no targets argument
549
+
550
+ imgs = concat_all_gather(imgs)
551
+ if self.accelerator.num_processes > 16:
552
+ imgs = imgs[:16*len(self.eval_classes)]
553
+ imgs = imgs.detach().cpu()
554
+ grid = make_grid(
555
+ imgs, nrow=len(self.eval_classes), normalize=True, value_range=(0, 1)
556
+ )
557
+ if self.accelerator.is_main_process:
558
+ save_image(
559
+ grid,
560
+ os.path.join(
561
+ self.image_saved_dir, f"step{self.steps}_aecfg-{self.ae_cfg}_cfg-{self.cfg_schedule}-{self.cfg}_slots{self.num_slots_to_gen}_temp{self.temperature}.jpg"
562
+ ),
563
+ )
564
+ if self.eval_fid and (self.test_only or (self.steps % self.fid_every == 0)):
565
+ # Create output directory (only on main process)
566
+ save_folder = os.path.join(self.image_saved_dir, f"gen_step{self.steps}_aecfg-{self.ae_cfg}_cfg-{self.cfg_schedule}-{self.cfg}_slots{self.num_slots_to_gen}_temp{self.temperature}")
567
+ if self.accelerator.is_main_process:
568
+ os.makedirs(save_folder, exist_ok=True)
569
+
570
+ # Setup for distributed generation
571
+ world_size = self.accelerator.num_processes
572
+ local_rank = self.accelerator.process_index
573
+ batch_size = self.test_bs
574
+
575
+ # Create balanced class distribution
576
+ num_classes = self.num_classes
577
+ images_per_class = self.num_test_images // num_classes
578
+ class_labels = np.repeat(np.arange(num_classes), images_per_class)
579
+
580
+ # Shuffle the class labels to ensure random ordering
581
+ np.random.shuffle(class_labels)
582
+
583
+ total_images = len(class_labels)
584
+
585
+ padding_size = -total_images % (world_size * batch_size) # 0 when already evenly divisible
586
+ class_labels = np.pad(class_labels, (0, padding_size), 'constant')
587
+ padded_total_images = len(class_labels)
588
+
589
+ # Distribute workload across GPUs
590
+ images_per_gpu = padded_total_images // world_size
591
+ start_idx = local_rank * images_per_gpu
592
+ end_idx = min(start_idx + images_per_gpu, padded_total_images)
593
+ local_class_labels = class_labels[start_idx:end_idx]
594
+ local_num_steps = len(local_class_labels) // batch_size
595
+
596
+ if self.accelerator.is_main_process:
597
+ print(f"Generating {total_images} images ({images_per_class} per class) across {world_size} GPUs")
598
+
599
+ used_time = 0
600
+ gen_img_cnt = 0
601
+
602
+ for i in range(local_num_steps):
603
+ if self.accelerator.is_main_process and i % 10 == 0:
604
+ print(f"Generation step {i}/{local_num_steps}")
605
+
606
+ # Get and pad labels for current batch
607
+ batch_start = i * batch_size
608
+ batch_end = batch_start + batch_size
609
+ labels = local_class_labels[batch_start:batch_end]
610
+
611
+ # Convert to tensors and track real vs padding
612
+ labels = torch.tensor(labels, device=self.device)
613
+
614
+ # Generate images
615
+ self.accelerator.wait_for_everyone()
616
+ start_time = time.time()
617
+ with torch.no_grad():
618
+ with self.accelerator.autocast():
619
+ slots = generate(unwraped_gpt_model, labels, self.num_slots_to_gen,
620
+ cfg_scale=self.cfg,
621
+ cfg_schedule=self.cfg_schedule,
622
+ temperature=self.temperature)
623
+ if self.num_slots_to_gen < self.num_slots:
624
+ null_slots = self.ae_model.dit.null_cond.expand(slots.shape[0], -1, -1)
625
+ null_slots = null_slots[:, self.num_slots_to_gen:, :]
626
+ slots = torch.cat([slots, null_slots], dim=1)
627
+ imgs = self.ae_model.sample(slots, cfg=self.ae_cfg)
628
+
629
+ samples_in_batch = min(batch_size * world_size, total_images - gen_img_cnt)
630
+
631
+ # Update timing stats
632
+ used_time += time.time() - start_time
633
+ gen_img_cnt += samples_in_batch
634
+ if self.accelerator.is_main_process and i % 10 == 0:
635
+ print(f"Avg generation time: {used_time/gen_img_cnt:.5f} sec/image")
636
+
637
+ gathered_imgs = concat_all_gather(imgs)
638
+ gathered_imgs = gathered_imgs[:samples_in_batch]
639
+
640
+ # Save images (only on main process)
641
+ if self.accelerator.is_main_process:
642
+ real_imgs = gathered_imgs.detach().cpu()
643
+
644
+ save_paths = [
645
+ os.path.join(save_folder, f"{str(idx).zfill(5)}.png")
646
+ for idx in range(gen_img_cnt - samples_in_batch, gen_img_cnt)
647
+ ]
648
+ save_img_batch(real_imgs, save_paths)
649
+
650
+ # Calculate metrics (only on main process)
651
+ self.accelerator.wait_for_everyone()
652
+ if self.accelerator.is_main_process:
653
+ generated_files = len(os.listdir(save_folder))
654
+ print(f"Generated {generated_files} images out of {total_images} expected")
655
+
656
+ metrics_dict = get_fid_stats(save_folder, None, self.fid_stats)
657
+ fid = metrics_dict["frechet_inception_distance"]
658
+ inception_score = metrics_dict["inception_score_mean"]
659
+
660
+ metric_prefix = "fid_ema" if use_ema else "fid"
661
+ isc_prefix = "isc_ema" if use_ema else "isc"
662
+
663
+ self.accelerator.log({
664
+ metric_prefix: fid,
665
+ isc_prefix: inception_score,
666
+ "gpt_cfg": self.cfg,
667
+ "ae_cfg": self.ae_cfg,
668
+ "cfg_schedule": self.cfg_schedule,
669
+ "temperature": self.temperature,
670
+ "num_slots": self.test_num_slots if self.test_num_slots is not None else self.train_num_slots
671
+ }, step=self.steps)
672
+
673
+ # Print comprehensive CFG information
674
+ cfg_info = (
675
+ f"{'EMA ' if use_ema else ''}CFG params: "
676
+ f"gpt_cfg={self.cfg}, ae_cfg={self.ae_cfg}, "
677
+ f"cfg_schedule={self.cfg_schedule}, "
678
+ f"num_slots={self.test_num_slots if self.test_num_slots is not None else self.train_num_slots}, "
679
+ f"temperature={self.temperature}"
680
+ )
681
+ print(cfg_info)
682
+ print(f"FID: {fid:.2f}, ISC: {inception_score:.2f}")
683
+
684
+ # Cleanup
685
+ shutil.rmtree(save_folder)
686
+
687
+ # back to no ema
688
+ if use_ema:
689
+ if self.accelerator.is_main_process:
690
+ print("Switch back from ema")
691
+ model_without_ddp.load_state_dict(model_state_dict)
692
+
693
+ self.gpt_model.train()
694
+
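For reference, a condensed sketch of the sampling path that evaluate() drives, assuming gpt_model, ae_model, and generate (from semanticist.stage2.generate) are set up as in this trainer; the class IDs, cfg values, and num_slots_to_gen below are illustrative, not from the commit:

classes = torch.tensor([207, 360, 387, 974], device=device)
with torch.no_grad():
    slots = generate(gpt_model, classes, num_slots_to_gen,
                     cfg_scale=6.0, cfg_schedule="linear", temperature=1.0)
    if num_slots_to_gen < ae_model.num_slots:
        # pad the tail with the DiT's learned null condition (nested dropout)
        null = ae_model.dit.null_cond.expand(slots.shape[0], -1, -1)
        slots = torch.cat([slots, null[:, num_slots_to_gen:, :]], dim=1)
    imgs = ae_model.sample(slots, cfg=1.0)  # (B, 3, 256, 256) in [0, 1]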
semanticist/engine/trainer_utils.py ADDED
@@ -0,0 +1,251 @@
1
+ import os, torch
2
+ import cv2
3
+ import numpy as np
4
+ import torch_fidelity
5
+ from collections import OrderedDict
6
+ from concurrent.futures import ThreadPoolExecutor
7
+ import importlib
8
+ from torch.optim import AdamW
9
+ from semanticist.utils.lr_scheduler import build_scheduler
10
+
11
+
12
+ def get_obj_from_str(string, reload=False):
13
+ """Get object from string path."""
14
+ module, cls = string.rsplit(".", 1)
15
+ if reload:
16
+ module_imp = importlib.import_module(module)
17
+ importlib.reload(module_imp)
18
+ return getattr(importlib.import_module(module, package=None), cls)
19
+
20
+
21
+ def instantiate_from_config(config):
22
+ """Instantiate an object from a config dictionary."""
23
+ if "target" not in config:
24
+ raise KeyError("Expected key `target` to instantiate.")
25
+ return get_obj_from_str(config["target"])(**config.get("params", dict()))
26
+
27
+
28
+ def is_dist_avail_and_initialized():
+ """Check if distributed training is available and initialized."""
+ if not torch.distributed.is_available():
+ return False
+ if not torch.distributed.is_initialized():
+ return False
+ return True
33
+
34
+
35
+ def is_main_process():
36
+ """Check if the current process is the main process."""
37
+ return not is_dist_avail_and_initialized() or torch.distributed.get_rank() == 0
38
+
39
+
40
+ def concat_all_gather(tensor):
41
+ """
42
+ Performs all_gather operation on the provided tensors.
43
+ *** Warning ***: torch.distributed.all_gather has no gradient.
44
+ """
45
+ tensors_gather = [torch.ones_like(tensor)
46
+ for _ in range(torch.distributed.get_world_size())]
47
+ torch.distributed.all_gather(tensors_gather, tensor, async_op=False)
48
+
49
+ output = torch.cat(tensors_gather, dim=0)
50
+ return output
51
+
52
+
53
+ def requires_grad(model, flag=True):
54
+ """Set requires_grad flag for all model parameters."""
55
+ for p in model.parameters():
56
+ p.requires_grad = flag
57
+
58
+
59
+ def save_img(img, save_path):
60
+ """Save a single image to disk."""
61
+ img = np.clip(img.float().numpy().transpose([1, 2, 0]) * 255, 0, 255)
62
+ img = img.astype(np.uint8)[:, :, ::-1]
63
+ cv2.imwrite(save_path, img)
64
+
65
+
66
+ def save_img_batch(imgs, save_paths):
67
+ """Process and save multiple images at once using a thread pool."""
68
+ # Convert to numpy and prepare all images in one go
69
+ imgs = np.clip(imgs.float().numpy().transpose(0, 2, 3, 1) * 255, 0, 255).astype(np.uint8)
70
+ imgs = imgs[:, :, :, ::-1] # RGB to BGR for all images at once
71
+
72
+ with ThreadPoolExecutor(max_workers=32) as pool:
73
+ # Submit all tasks at once
74
+ futures = [pool.submit(cv2.imwrite, path, img)
75
+ for path, img in zip(save_paths, imgs)]
76
+ # Wait for all tasks to complete
77
+ for future in futures:
78
+ future.result() # This will raise any exceptions that occurred
79
+
80
+
81
+ def get_fid_stats(real_dir, rec_dir, fid_stats):
82
+ """Calculate FID statistics between real and reconstructed images."""
83
+ stats = torch_fidelity.calculate_metrics(
84
+ input1=rec_dir,
85
+ input2=real_dir,
86
+ fid_statistics_file=fid_stats,
87
+ cuda=True,
88
+ isc=True,
89
+ fid=True,
90
+ kid=False,
91
+ prc=False,
92
+ verbose=False,
93
+ )
94
+ return stats
95
+
96
+
97
+ def create_scheduler(optimizer, num_epoch, steps_per_epoch, lr_min, warmup_steps,
98
+ warmup_lr_init, decay_steps, cosine_lr):
99
+ """Create a learning rate scheduler."""
100
+ scheduler = build_scheduler(
101
+ optimizer,
102
+ num_epoch,
103
+ steps_per_epoch,
104
+ lr_min,
105
+ warmup_steps,
106
+ warmup_lr_init,
107
+ decay_steps,
108
+ cosine_lr,
109
+ )
110
+ return scheduler
111
+
112
+
113
+ def load_state_dict(state_dict, model):
114
+ """Helper to load a state dict with proper prefix handling."""
115
+ if 'state_dict' in state_dict:
116
+ state_dict = state_dict['state_dict']
117
+ # Remove '_orig_mod' prefix if present
118
+ state_dict = {k.replace('_orig_mod.', ''): v for k, v in state_dict.items()}
119
+ missing, unexpected = model.load_state_dict(
120
+ state_dict, strict=False
121
+ )
122
+ if is_main_process():
123
+ print(f"Loaded model. Missing: {missing}, Unexpected: {unexpected}")
124
+
125
+
126
+ def load_safetensors(path, model):
127
+ """Helper to load a safetensors checkpoint."""
128
+ from safetensors.torch import safe_open
129
+ with safe_open(path, framework="pt", device="cpu") as f:
130
+ state_dict = {k: f.get_tensor(k) for k in f.keys()}
131
+ load_state_dict(state_dict, model)
132
+
133
+
134
+ def setup_result_folders(result_folder):
135
+ """Setup result folders for saving models and images."""
136
+ model_saved_dir = os.path.join(result_folder, "models")
137
+ os.makedirs(model_saved_dir, exist_ok=True)
138
+
139
+ image_saved_dir = os.path.join(result_folder, "images")
140
+ os.makedirs(image_saved_dir, exist_ok=True)
141
+
142
+ return model_saved_dir, image_saved_dir
143
+
144
+
145
+ def create_optimizer(model, weight_decay, learning_rate, betas=(0.9, 0.95)):
146
+ """Create an AdamW optimizer with weight decay for 2D parameters only."""
147
+ # start with all of the candidate parameters
148
+ param_dict = {pn: p for pn, p in model.named_parameters()}
149
+ # filter out those that do not require grad
150
+ param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
151
+ # create optim groups. Any parameter that is at least 2D will be weight-decayed; others will not.
152
+ # i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
153
+ decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
154
+ nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
155
+ optim_groups = [
156
+ {'params': decay_params, 'weight_decay': weight_decay},
157
+ {'params': nodecay_params, 'weight_decay': 0.0}
158
+ ]
159
+ num_decay_params = sum(p.numel() for p in decay_params)
160
+ num_nodecay_params = sum(p.numel() for p in nodecay_params)
161
+ if is_main_process():
162
+ print(f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters")
163
+ print(f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters")
164
+ optimizer = AdamW(optim_groups, lr=learning_rate, betas=betas)
165
+ return optimizer
166
+
167
+
168
+ class EMAModel:
169
+ """Model Exponential Moving Average."""
170
+ def __init__(self, model, device, decay=0.999):
171
+ self.device = device
172
+ self.decay = decay
173
+ self.ema_params = OrderedDict(
174
+ (name, param.clone().detach().to(device))
175
+ for name, param in model.named_parameters()
176
+ if param.requires_grad
177
+ )
178
+
179
+ @torch.no_grad()
180
+ def update(self, model):
181
+ for name, param in model.named_parameters():
182
+ if param.requires_grad:
183
+ if name in self.ema_params:
184
+ self.ema_params[name].lerp_(param.data, 1 - self.decay)
185
+ else:
186
+ self.ema_params[name] = param.data.clone().detach()
187
+
188
+ def state_dict(self):
189
+ return self.ema_params
190
+
191
+ def load_state_dict(self, params):
192
+ self.ema_params = OrderedDict(
193
+ (name, param.clone().detach().to(self.device))
194
+ for name, param in params.items()
195
+ )
196
+
197
+
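A minimal usage sketch for EMAModel (toy model and loop, not from the commit); note that update() computes ema <- decay * ema + (1 - decay) * param via lerp_:

import torch
import torch.nn as nn

model = nn.Linear(8, 8)
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
ema = EMAModel(model, device="cpu", decay=0.999)
for _ in range(10):
    loss = model(torch.randn(4, 8)).pow(2).mean()
    opt.zero_grad(); loss.backward(); opt.step()
    ema.update(model)
# evaluate with the smoothed weights
eval_model = nn.Linear(8, 8)
eval_model.load_state_dict(ema.state_dict())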
198
+ class PaddedDataset(torch.utils.data.Dataset):
199
+ """Dataset wrapper that pads a dataset to ensure even distribution across processes."""
200
+ def __init__(self, dataset, padding_size):
201
+ self.dataset = dataset
202
+ self.padding_size = padding_size
203
+
204
+ def __len__(self):
205
+ return len(self.dataset) + self.padding_size
206
+
207
+ def __getitem__(self, idx):
208
+ if idx < len(self.dataset):
209
+ return self.dataset[idx]
210
+ return self.dataset[0]
211
+
212
+ class CacheDataLoader:
213
+ """DataLoader-like interface for cached data with epoch-based shuffling."""
214
+ def __init__(self, slots, targets=None, batch_size=32, num_augs=1, seed=None):
215
+ self.slots = slots
216
+ self.targets = targets
217
+ self.batch_size = batch_size
218
+ self.num_augs = num_augs
219
+ self.seed = seed
220
+ self.epoch = 0
221
+ # Original dataset size (before augmentations)
222
+ self.num_samples = len(slots) // num_augs
223
+
224
+ def set_epoch(self, epoch):
225
+ """Set epoch for deterministic shuffling."""
226
+ self.epoch = epoch
227
+
228
+ def __len__(self):
229
+ """Return number of batches based on original dataset size."""
230
+ return self.num_samples // self.batch_size
231
+
232
+ def __iter__(self):
233
+ """Return random indices for current epoch."""
234
+ g = torch.Generator()
235
+ g.manual_seed(self.seed + self.epoch if self.seed is not None else self.epoch)
236
+
237
+ # Randomly sample indices from the entire augmented dataset
238
+ indices = torch.randint(
239
+ 0, len(self.slots),
240
+ (self.num_samples,),
241
+ generator=g
242
+ ).numpy()
243
+
244
+ # Yield batches of indices
245
+ for start in range(0, self.num_samples, self.batch_size):
246
+ end = min(start + self.batch_size, self.num_samples)
247
+ batch_indices = indices[start:end]
248
+ yield (
249
+ torch.from_numpy(self.slots[batch_indices]),
250
+ torch.from_numpy(self.targets[batch_indices])
251
+ )
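Illustrative driving of CacheDataLoader over pre-extracted latents (array names hypothetical); note that __iter__ always indexes targets, so despite the targets=None default both arrays must be supplied, and __len__ floors while the last yielded batch may be partial:

import numpy as np

slots_np = np.random.randn(10_000, 32, 16).astype(np.float32)  # cached slots
targets_np = np.random.randint(0, 1000, size=(10_000,))        # class labels
loader = CacheDataLoader(slots_np, targets_np, batch_size=256, num_augs=1, seed=0)
for epoch in range(2):
    loader.set_epoch(epoch)      # deterministic per-epoch reshuffle
    for slots, targets in loader:
        pass                     # slots: (B, 32, 16) tensor, targets: (B,)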
semanticist/stage1/diffuse_slot.py ADDED
@@ -0,0 +1,452 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from diffusers import AutoencoderKL
7
+ from semanticist.stage1 import vision_transformer
8
+ from semanticist.stage1.diffusion import create_diffusion
9
+ from semanticist.stage1.diffusion_transfomer import DiT
10
+ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
11
+
12
+ class DiT_with_autoenc_cond(DiT):
13
+ def __init__(
14
+ self,
15
+ *args,
16
+ num_autoenc=32,
17
+ autoenc_dim=4,
18
+ use_repa=False,
19
+ z_dim=768,
20
+ encoder_depth=8,
21
+ projector_dim=2048,
22
+ **kwargs,
23
+ ):
24
+ super().__init__(*args, **kwargs)
25
+ self.autoenc_dim = autoenc_dim
26
+ self.hidden_size = kwargs["hidden_size"]
27
+ self.null_cond = nn.Parameter(torch.zeros(1, num_autoenc, autoenc_dim))
28
+ torch.nn.init.normal_(self.null_cond, std=.02)
29
+ self.autoenc_cond_embedder = nn.Linear(autoenc_dim, self.hidden_size)
30
+ self.y_embedder = nn.Identity()
31
+ self.cond_drop_prob = 0.1
32
+
33
+ self.use_repa = use_repa
34
+ self._repa_hook = None
35
+ self.encoder_depth = encoder_depth
36
+ if use_repa:
37
+ self.projector = build_mlp(self.hidden_size, projector_dim, z_dim)
38
+
39
+ def embed_cond(self, autoenc_cond, drop_mask=None):
40
+ # autoenc_cond: (N, K, D)
41
+ # drop_mask: (N, K), True means keep that slot's condition
42
+ # self.null_cond: (1, K, D)
43
+ batch_size = autoenc_cond.shape[0]
44
+ if drop_mask is None:
45
+ # randomly drop all conditions, for classifier-free guidance
46
+ if self.training:
47
+ drop_ids = (
48
+ torch.rand(batch_size, 1, 1, device=autoenc_cond.device)
49
+ < self.cond_drop_prob
50
+ )
51
+ autoenc_cond_drop = torch.where(drop_ids, self.null_cond, autoenc_cond)
52
+ else:
53
+ autoenc_cond_drop = autoenc_cond
54
+ else:
55
+ # randomly drop some conditions according to the drop_mask (N, K)
56
+ # True means keep
57
+ autoenc_cond_drop = torch.where(drop_mask[:, :, None], autoenc_cond, self.null_cond)
58
+ return self.autoenc_cond_embedder(autoenc_cond_drop)
59
+
60
+ def forward(self, x, t, autoenc_cond, drop_mask=None):
61
+ """
62
+ Forward pass of DiT.
63
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
64
+ t: (N,) tensor of diffusion timesteps
65
+ autoenc_cond: (N, K, D) tensor of autoencoder conditions (slots)
66
+ """
67
+ x = (
68
+ self.x_embedder(x) + self.pos_embed
69
+ ) # (N, T, D), where T = H * W / patch_size ** 2
70
+ c = self.t_embedder(t) # (N, D)
71
+ autoenc = self.embed_cond(autoenc_cond, drop_mask)
72
+ num_tokens = x.shape[1]
73
+ x = torch.cat((x, autoenc), dim=1)
74
+
75
+ for i, block in enumerate(self.blocks):
76
+ x = block(x, c) # (N, T, D)
77
+ if (i + 1) == self.encoder_depth and self.use_repa:
78
+ projected = self.projector(x)
79
+ self._repa_hook = projected[:, :num_tokens]
80
+
81
+ x = x[:, :num_tokens]
82
+ x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
83
+ x = self.unpatchify(x) # (N, out_channels, H, W)
84
+ return x
85
+
86
+ def forward_with_cfg(self, x, t, autoenc_cond, drop_mask, y=None, cfg_scale=1.0):
87
+ """
88
+ Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
89
+ """
90
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
91
+ half = x[: len(x) // 2]
92
+ combined = torch.cat([half, half], dim=0)
93
+ model_out = self.forward(combined, t, autoenc_cond, drop_mask)
94
+ eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
95
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
96
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
97
+ eps = torch.cat([half_eps, half_eps], dim=0)
98
+ return torch.cat([eps, rest], dim=1)
99
+
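The caller-side convention (as in the GLIDE/DiT samplers) is to duplicate the noise batch and pair it with the conditional slots followed by the null condition, so both halves share the same x_t; a sketch with illustrative tensor names:

z_in = torch.cat([z, z], dim=0)               # (2n, C, H, W), halves identical
cond = torch.cat([slots, null_slots], dim=0)  # (2n, K, D)
t2 = torch.cat([t, t], dim=0)
out = dit.forward_with_cfg(z_in, t2, cond, None, cfg_scale=4.0)
eps = out[:n, :dit.in_channels]               # guided eps; both halves are equal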
100
+ #################################################################################
101
+ # DiT Configs #
102
+ #################################################################################
103
+
104
+
105
+ def DiT_with_autoenc_cond_XL_2(**kwargs):
106
+ return DiT_with_autoenc_cond(
107
+ depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs
108
+ )
109
+
110
+
111
+ def DiT_with_autoenc_cond_XL_4(**kwargs):
112
+ return DiT_with_autoenc_cond(
113
+ depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs
114
+ )
115
+
116
+
117
+ def DiT_with_autoenc_cond_XL_8(**kwargs):
118
+ return DiT_with_autoenc_cond(
119
+ depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs
120
+ )
121
+
122
+
123
+ def DiT_with_autoenc_cond_L_2(**kwargs):
124
+ return DiT_with_autoenc_cond(
125
+ depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs
126
+ )
127
+
128
+
129
+ def DiT_with_autoenc_cond_L_4(**kwargs):
130
+ return DiT_with_autoenc_cond(
131
+ depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs
132
+ )
133
+
134
+
135
+ def DiT_with_autoenc_cond_L_8(**kwargs):
136
+ return DiT_with_autoenc_cond(
137
+ depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs
138
+ )
139
+
140
+
141
+ def DiT_with_autoenc_cond_B_2(**kwargs):
142
+ return DiT_with_autoenc_cond(
143
+ depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs
144
+ )
145
+
146
+
147
+ def DiT_with_autoenc_cond_B_4(**kwargs):
148
+ return DiT_with_autoenc_cond(
149
+ depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs
150
+ )
151
+
152
+
153
+ def DiT_with_autoenc_cond_B_8(**kwargs):
154
+ return DiT_with_autoenc_cond(
155
+ depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs
156
+ )
157
+
158
+
159
+ def DiT_with_autoenc_cond_S_2(**kwargs):
160
+ return DiT_with_autoenc_cond(
161
+ depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs
162
+ )
163
+
164
+
165
+ def DiT_with_autoenc_cond_S_4(**kwargs):
166
+ return DiT_with_autoenc_cond(
167
+ depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs
168
+ )
169
+
170
+
171
+ def DiT_with_autoenc_cond_S_8(**kwargs):
172
+ return DiT_with_autoenc_cond(
173
+ depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs
174
+ )
175
+
176
+
177
+ DiT_with_autoenc_cond_models = {
178
+ "DiT-XL-2": DiT_with_autoenc_cond_XL_2,
179
+ "DiT-XL-4": DiT_with_autoenc_cond_XL_4,
180
+ "DiT-XL-8": DiT_with_autoenc_cond_XL_8,
181
+ "DiT-L-2": DiT_with_autoenc_cond_L_2,
182
+ "DiT-L-4": DiT_with_autoenc_cond_L_4,
183
+ "DiT-L-8": DiT_with_autoenc_cond_L_8,
184
+ "DiT-B-2": DiT_with_autoenc_cond_B_2,
185
+ "DiT-B-4": DiT_with_autoenc_cond_B_4,
186
+ "DiT-B-8": DiT_with_autoenc_cond_B_8,
187
+ "DiT-S-2": DiT_with_autoenc_cond_S_2,
188
+ "DiT-S-4": DiT_with_autoenc_cond_S_4,
189
+ "DiT-S-8": DiT_with_autoenc_cond_S_8,
190
+ }
191
+
192
+ class NestedSampler(nn.Module):
193
+ def __init__(
194
+ self,
195
+ num_slots,
196
+ ):
197
+ super().__init__()
198
+ self.num_slots = num_slots
199
+ self.register_buffer("arange", torch.arange(num_slots))
200
+
201
+ def uniform_sample(self, num):
202
+ return torch.randint(1, self.num_slots + 1, (num,))
203
+
204
+ def sample(self, num):
205
+ samples = self.uniform_sample(num)
206
+ return samples
207
+
208
+ def forward(self, batch_size, device, inference_with_n_slots=-1):
209
+ if self.training:
210
+ b = self.sample(batch_size).to(device)
211
+ else:
212
+ if inference_with_n_slots != -1:
213
+ b = torch.full((batch_size,), inference_with_n_slots, device=device)
214
+ else:
215
+ b = torch.full((batch_size,), self.num_slots, device=device)
216
+ b = torch.clamp(b, max=self.num_slots)
217
+
218
+ slot_mask = self.arange[None, :] < b[:, None] # (batch_size, num_slots)
219
+ return slot_mask
220
+
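At inference the mask is deterministic: with inference_with_n_slots=k exactly the first k slot positions are kept, which is what lets the same model decode from any prefix of the slot sequence. A toy check:

sampler = NestedSampler(num_slots=4).eval()
mask = sampler(batch_size=2, device="cpu", inference_with_n_slots=3)
# tensor([[ True,  True,  True, False],
#         [ True,  True,  True, False]])
# During training, b ~ Uniform{1, ..., num_slots} is drawn per example instead.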
221
+ class DiffuseSlot(nn.Module):
222
+ def __init__(
223
+ self,
224
+ encoder="vit_base_patch16",
225
+ drop_path_rate=0.1,
226
+ enc_img_size=256,
227
+ enc_causal=True,
228
+ num_slots=16,
229
+ slot_dim=256,
230
+ norm_slots=False,
231
+ enable_nest=False,
232
+ enable_nest_after=-1,
233
+ vae="stabilityai/sd-vae-ft-ema",
234
+ dit_model="DiT-B-4",
235
+ num_sampling_steps="ddim25",
236
+ use_repa=False,
237
+ repa_encoder_depth=8,
238
+ repa_loss_weight=1.0,
239
+ **kwargs,
240
+ ):
241
+ super().__init__()
242
+
243
+ self.use_repa = use_repa
244
+ self.repa_loss_weight = repa_loss_weight
245
+ if use_repa:
246
+ self.repa_encoder = torch.hub.load('facebookresearch/dinov2', 'dinov2_vitb14')
247
+ self.repa_encoder.image_size = 224
248
+ for param in self.repa_encoder.parameters():
249
+ param.requires_grad = False
250
+ self.repa_encoder.eval()
251
+
252
+ self.diffusion = create_diffusion(timestep_respacing="")
253
+ self.gen_diffusion = create_diffusion(timestep_respacing=num_sampling_steps)
254
+ self.dit_input_size = enc_img_size // 8 if "mar" not in vae else enc_img_size // 16
255
+ self.dit_in_channels = 4 if "mar" not in vae else 16
256
+ self.dit = DiT_with_autoenc_cond_models[dit_model](
257
+ input_size=self.dit_input_size,
258
+ in_channels=self.dit_in_channels,
259
+ num_autoenc=num_slots,
260
+ autoenc_dim=slot_dim,
261
+ use_repa=use_repa,
262
+ encoder_depth=repa_encoder_depth,
263
+ z_dim=768,
264
+ )
265
+ self.vae = AutoencoderKL.from_pretrained(vae)
266
+ self.scaling_factor = self.vae.config.scaling_factor
267
+ self.vae.eval().requires_grad_(False)
268
+
269
+ self.enc_img_size = enc_img_size
270
+ self.enc_causal = enc_causal
271
+ encoder_fn = vision_transformer.__dict__[encoder]
272
+
273
+ self.encoder = encoder_fn(
274
+ img_size=[enc_img_size],
275
+ num_slots=num_slots,
276
+ drop_path_rate=drop_path_rate,
277
+ )
278
+ self.num_slots = num_slots
279
+ self.norm_slots = norm_slots
280
+ self.num_channels = self.encoder.num_features
281
+
282
+ self.encoder2slot = nn.Linear(self.num_channels, slot_dim)
283
+ self.nested_sampler = NestedSampler(num_slots)
284
+ self.enable_nest = enable_nest
285
+ self.enable_nest_after = enable_nest_after
286
+
287
+ @torch.no_grad()
288
+ def vae_encode(self, x):
289
+ x = x * 2 - 1
290
+ x = self.vae.encode(x)
291
+ if hasattr(x, 'latent_dist'):
292
+ x = x.latent_dist
293
+ return x.sample().mul_(self.scaling_factor)
294
+
295
+ @torch.no_grad()
296
+ def vae_decode(self, z):
297
+ z = self.vae.decode(z / self.scaling_factor)
298
+ if hasattr(z, 'sample'):
299
+ z = z.sample
300
+ return (z + 1) / 2
301
+
302
+ @torch.no_grad()
303
+ def repa_encode(self, x):
304
+ mean = torch.Tensor(IMAGENET_DEFAULT_MEAN).to(x.device).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
305
+ std = torch.Tensor(IMAGENET_DEFAULT_STD).to(x.device).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
306
+ x = (x - mean) / std
307
+ if self.repa_encoder.image_size != self.enc_img_size:
308
+ x = torch.nn.functional.interpolate(x, self.repa_encoder.image_size, mode='bicubic')
309
+ x = self.repa_encoder.forward_features(x)['x_norm_patchtokens']
310
+ return x
311
+
312
+ def encode_slots(self, x):
313
+ slots = self.encoder(x, is_causal=self.enc_causal)
314
+ slots = self.encoder2slot(slots)
315
+ if self.norm_slots:
316
+ slots_std = torch.std(slots, dim=-1, keepdim=True)
317
+ slots_mean = torch.mean(slots, dim=-1, keepdim=True)
318
+ slots = (slots - slots_mean) / slots_std
319
+ return slots
320
+
321
+ def forward_with_latents(self,
322
+ x_vae,
323
+ slots,
324
+ z,
325
+ sample=False,
326
+ epoch=None,
327
+ inference_with_n_slots=-1,
328
+ cfg=1.0):
329
+ losses = {}
330
+ batch_size = x_vae.shape[0]
331
+ device = x_vae.device
332
+
333
+ if (
334
+ epoch is not None
335
+ and epoch >= self.enable_nest_after
336
+ and self.enable_nest_after != -1
337
+ ):
338
+ self.enable_nest = True
339
+
340
+ t = torch.randint(0, self.diffusion.num_timesteps, (x_vae.shape[0],), device=device)
341
+
342
+ if self.enable_nest or inference_with_n_slots != -1:
343
+ drop_mask = self.nested_sampler(
344
+ batch_size, device,
345
+ inference_with_n_slots=inference_with_n_slots,
346
+ )
347
+ else:
348
+ drop_mask = None
349
+
350
+ if sample:
351
+ return self.sample(slots, drop_mask=drop_mask, cfg=cfg)
352
+
353
+ model_kwargs = dict(autoenc_cond=slots, drop_mask=drop_mask)
354
+ loss_dict = self.diffusion.training_losses(self.dit, x_vae, t, model_kwargs)
355
+ diff_loss = loss_dict["loss"].mean()
356
+ losses["diff_loss"] = diff_loss
357
+
358
+ if self.use_repa:
359
+ assert self.dit._repa_hook is not None and z is not None
360
+ z_tilde = self.dit._repa_hook
361
+
362
+ if z_tilde.shape[1] != z.shape[1]:
363
+ z_tilde = interpolate_features(z_tilde, z.shape[1])
364
+
365
+ z_tilde = F.normalize(z_tilde, dim=-1)
366
+ z = F.normalize(z, dim=-1)
367
+ repa_loss = -torch.sum(z_tilde * z, dim=-1)
368
+ losses["repa_loss"] = repa_loss.mean() * self.repa_loss_weight
369
+
370
+ return losses
371
+
372
+
373
+ def forward(self,
374
+ x,
375
+ sample=False,
376
+ epoch=None,
377
+ inference_with_n_slots=-1,
378
+ cfg=1.0):
379
+
380
+ x_vae = self.vae_encode(x)
381
+ z = self.repa_encode(x) if self.use_repa else None
382
+ slots = self.encode_slots(x)
383
+ return self.forward_with_latents(x_vae, slots, z, sample, epoch, inference_with_n_slots, cfg)
384
+
385
+
386
+ @torch.no_grad()
387
+ def sample(self, slots, drop_mask=None, cfg=1.0):
388
+ batch_size = slots.shape[0]
389
+ device = slots.device
390
+ z = torch.randn(batch_size, self.dit_in_channels, self.dit_input_size, self.dit_input_size, device=device)
391
+ if cfg != 1.0:
392
+ z = torch.cat([z, z], 0)
393
+ null_slots = self.dit.null_cond.expand(batch_size, -1, -1)
394
+ slots = torch.cat([slots, null_slots], 0)
395
+ if drop_mask is not None:
396
+ null_cond_mask = torch.ones_like(drop_mask)
397
+ drop_mask = torch.cat([drop_mask, null_cond_mask], 0)
398
+ model_kwargs = dict(autoenc_cond=slots, drop_mask=drop_mask, cfg_scale=cfg)
399
+ sample_fn = self.dit.forward_with_cfg
400
+ else:
401
+ model_kwargs = dict(autoenc_cond=slots, drop_mask=drop_mask)
402
+ sample_fn = self.dit.forward
403
+ samples = self.gen_diffusion.p_sample_loop(
404
+ sample_fn,
405
+ z.shape,
406
+ z,
407
+ clip_denoised=False,
408
+ model_kwargs=model_kwargs,
409
+ progress=False,
410
+ device=device,
411
+ )
412
+ if cfg != 1.0:
413
+ samples, _ = samples.chunk(2, dim=0) # Remove null class samples
414
+ samples = self.vae_decode(samples)
415
+ return samples
416
+
417
+ def train(self, mode=True):
418
+ """Override train() to keep certain components in eval mode"""
419
+ super().train(mode)
420
+ self.vae.eval()
421
+ return self
422
+
423
+
424
+ def build_mlp(hidden_size, projector_dim, z_dim):
425
+ return nn.Sequential(
426
+ nn.Linear(hidden_size, projector_dim),
427
+ nn.SiLU(),
428
+ nn.Linear(projector_dim, projector_dim),
429
+ nn.SiLU(),
430
+ nn.Linear(projector_dim, z_dim),
431
+ )
432
+
433
+ def interpolate_features(x, target_len):
434
+ """Interpolate features to match target sequence length.
435
+ Args:
436
+ x: tensor of shape (B, T1, D)
437
+ target_len: desired sequence length T2
438
+ Returns:
439
+ tensor of shape (B, T2, D)
440
+ """
441
+ B, T1, D = x.shape
442
+ H1 = W1 = int(math.sqrt(T1))
443
+ H2 = W2 = int(math.sqrt(target_len))
444
+
445
+ # Reshape to 2D spatial dimensions and move channels to second dimension
446
+ x = x.reshape(B, H1, W1, D).permute(0, 3, 1, 2)
447
+
448
+ # Interpolate
449
+ x = F.interpolate(x, size=(H2, W2), mode='bicubic', align_corners=False)
450
+
451
+ # Reshape back to sequence
452
+ return x.permute(0, 2, 3, 1).reshape(B, target_len, D)
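An end-to-end reconstruction sketch with DiffuseSlot (hyperparameters illustrative, checkpoint loading omitted; cond_method and other extras are absorbed by **kwargs):

model = DiffuseSlot(encoder="vit_base_patch16", enc_img_size=256, enc_causal=True,
                    num_slots=256, slot_dim=16, norm_slots=True,
                    dit_model="DiT-L-2", vae="xwen99/mar-vae-kl16",
                    num_sampling_steps="ddim25").eval()
x = torch.rand(1, 3, 256, 256)            # image in [0, 1]
with torch.no_grad():
    slots = model.encode_slots(x)         # (1, 256, 16)
    recon = model.sample(slots, cfg=1.0)  # (1, 3, 256, 256) in [0, 1]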
semanticist/stage1/diffusion/__init__.py ADDED
@@ -0,0 +1,46 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ from . import gaussian_diffusion as gd
7
+ from .respace import SpacedDiffusion, space_timesteps
8
+
9
+
10
+ def create_diffusion(
11
+ timestep_respacing,
12
+ noise_schedule="linear",
13
+ use_kl=False,
14
+ sigma_small=False,
15
+ predict_xstart=False,
16
+ learn_sigma=True,
17
+ rescale_learned_sigmas=False,
18
+ diffusion_steps=1000
19
+ ):
20
+ betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
21
+ if use_kl:
22
+ loss_type = gd.LossType.RESCALED_KL
23
+ elif rescale_learned_sigmas:
24
+ loss_type = gd.LossType.RESCALED_MSE
25
+ else:
26
+ loss_type = gd.LossType.MSE
27
+ if timestep_respacing is None or timestep_respacing == "":
28
+ timestep_respacing = [diffusion_steps]
29
+ return SpacedDiffusion(
30
+ use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
31
+ betas=betas,
32
+ model_mean_type=(
33
+ gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
34
+ ),
35
+ model_var_type=(
36
+ (
37
+ gd.ModelVarType.FIXED_LARGE
38
+ if not sigma_small
39
+ else gd.ModelVarType.FIXED_SMALL
40
+ )
41
+ if not learn_sigma
42
+ else gd.ModelVarType.LEARNED_RANGE
43
+ ),
44
+ loss_type=loss_type
45
+ # rescale_timesteps=rescale_timesteps,
46
+ )
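Typical instantiation, matching how DiffuseSlot uses it above: one full-length process for training losses and a respaced one for sampling (illustrative):

train_diff = create_diffusion(timestep_respacing="")      # full 1000-step chain
gen_diff = create_diffusion(timestep_respacing="ddim25")  # 25 DDIM steps
# a plain number such as "250" selects 250 respaced DDPM steps instead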
semanticist/stage1/diffusion/diffusion_utils.py ADDED
@@ -0,0 +1,88 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ import torch as th
7
+ import numpy as np
8
+
9
+
10
+ def normal_kl(mean1, logvar1, mean2, logvar2):
11
+ """
12
+ Compute the KL divergence between two gaussians.
13
+ Shapes are automatically broadcasted, so batches can be compared to
14
+ scalars, among other use cases.
15
+ """
16
+ tensor = None
17
+ for obj in (mean1, logvar1, mean2, logvar2):
18
+ if isinstance(obj, th.Tensor):
19
+ tensor = obj
20
+ break
21
+ assert tensor is not None, "at least one argument must be a Tensor"
22
+
23
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
24
+ # Tensors, but it does not work for th.exp().
25
+ logvar1, logvar2 = [
26
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
27
+ for x in (logvar1, logvar2)
28
+ ]
29
+
30
+ return 0.5 * (
31
+ -1.0
32
+ + logvar2
33
+ - logvar1
34
+ + th.exp(logvar1 - logvar2)
35
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
36
+ )
37
+
38
+
39
+ def approx_standard_normal_cdf(x):
40
+ """
41
+ A fast approximation of the cumulative distribution function of the
42
+ standard normal.
43
+ """
44
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
45
+
46
+
47
+ def continuous_gaussian_log_likelihood(x, *, means, log_scales):
48
+ """
49
+ Compute the log-likelihood of a continuous Gaussian distribution.
50
+ :param x: the targets
51
+ :param means: the Gaussian mean Tensor.
52
+ :param log_scales: the Gaussian log stddev Tensor.
53
+ :return: a tensor like x of log probabilities (in nats).
54
+ """
55
+ centered_x = x - means
56
+ inv_stdv = th.exp(-log_scales)
57
+ normalized_x = centered_x * inv_stdv
58
+ log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
59
+ return log_probs
60
+
61
+
62
+ def discretized_gaussian_log_likelihood(x, *, means, log_scales):
63
+ """
64
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
65
+ given image.
66
+ :param x: the target images. It is assumed that this was uint8 values,
67
+ rescaled to the range [-1, 1].
68
+ :param means: the Gaussian mean Tensor.
69
+ :param log_scales: the Gaussian log stddev Tensor.
70
+ :return: a tensor like x of log probabilities (in nats).
71
+ """
72
+ assert x.shape == means.shape == log_scales.shape
73
+ centered_x = x - means
74
+ inv_stdv = th.exp(-log_scales)
75
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
76
+ cdf_plus = approx_standard_normal_cdf(plus_in)
77
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
78
+ cdf_min = approx_standard_normal_cdf(min_in)
79
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
80
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
81
+ cdf_delta = cdf_plus - cdf_min
82
+ log_probs = th.where(
83
+ x < -0.999,
84
+ log_cdf_plus,
85
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
86
+ )
87
+ assert log_probs.shape == x.shape
88
+ return log_probs
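A quick sanity check of normal_kl against torch.distributions (note the arguments are log-variances, while Normal takes stddevs; illustrative):

import torch as th
from torch.distributions import Normal, kl_divergence

m1, lv1 = th.tensor(0.3), th.tensor(-0.5)
m2, lv2 = th.tensor(-0.1), th.tensor(0.2)
ref = kl_divergence(Normal(m1, (0.5 * lv1).exp()), Normal(m2, (0.5 * lv2).exp()))
assert th.allclose(normal_kl(m1, lv1, m2, lv2), ref, atol=1e-6)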
semanticist/stage1/diffusion/gaussian_diffusion.py ADDED
@@ -0,0 +1,886 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+
7
+ import math
8
+
9
+ import numpy as np
10
+ import torch as th
11
+ import enum
12
+
13
+ from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
14
+
15
+
16
+ def mean_flat(tensor):
17
+ """
18
+ Take the mean over all non-batch dimensions.
19
+ """
20
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
21
+
22
+
23
+ class ModelMeanType(enum.Enum):
24
+ """
25
+ Which type of output the model predicts.
26
+ """
27
+
28
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
29
+ START_X = enum.auto() # the model predicts x_0
30
+ EPSILON = enum.auto() # the model predicts epsilon
31
+
32
+
33
+ class ModelVarType(enum.Enum):
34
+ """
35
+ What is used as the model's output variance.
36
+ The LEARNED_RANGE option has been added to allow the model to predict
37
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
38
+ """
39
+
40
+ LEARNED = enum.auto()
41
+ FIXED_SMALL = enum.auto()
42
+ FIXED_LARGE = enum.auto()
43
+ LEARNED_RANGE = enum.auto()
44
+
45
+
46
+ class LossType(enum.Enum):
47
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
48
+ RESCALED_MSE = (
49
+ enum.auto()
50
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
51
+ KL = enum.auto() # use the variational lower-bound
52
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
53
+
54
+ def is_vb(self):
55
+ return self == LossType.KL or self == LossType.RESCALED_KL
56
+
57
+
58
+ def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
59
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
60
+ warmup_time = int(num_diffusion_timesteps * warmup_frac)
61
+ betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
62
+ return betas
63
+
64
+
65
+ def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
66
+ """
67
+ This is the deprecated API for creating beta schedules.
68
+ See get_named_beta_schedule() for the new library of schedules.
69
+ """
70
+ if beta_schedule == "quad":
71
+ betas = (
72
+ np.linspace(
73
+ beta_start ** 0.5,
74
+ beta_end ** 0.5,
75
+ num_diffusion_timesteps,
76
+ dtype=np.float64,
77
+ )
78
+ ** 2
79
+ )
80
+ elif beta_schedule == "linear":
81
+ betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
82
+ elif beta_schedule == "warmup10":
83
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
84
+ elif beta_schedule == "warmup50":
85
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
86
+ elif beta_schedule == "const":
87
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
88
+ elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
89
+ betas = 1.0 / np.linspace(
90
+ num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
91
+ )
92
+ else:
93
+ raise NotImplementedError(beta_schedule)
94
+ assert betas.shape == (num_diffusion_timesteps,)
95
+ return betas
96
+
97
+
98
+ def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
99
+ """
100
+ Get a pre-defined beta schedule for the given name.
101
+ The beta schedule library consists of beta schedules which remain similar
102
+ in the limit of num_diffusion_timesteps.
103
+ Beta schedules may be added, but should not be removed or changed once
104
+ they are committed to maintain backwards compatibility.
105
+ """
106
+ if schedule_name == "linear":
107
+ # Linear schedule from Ho et al., extended to work for any number of
108
+ # diffusion steps.
109
+ scale = 1000 / num_diffusion_timesteps
110
+ return get_beta_schedule(
111
+ "linear",
112
+ beta_start=scale * 0.0001,
113
+ beta_end=scale * 0.02,
114
+ num_diffusion_timesteps=num_diffusion_timesteps,
115
+ )
116
+ elif schedule_name == "cosine":
117
+ return betas_for_alpha_bar(
118
+ num_diffusion_timesteps,
119
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
120
+ )
121
+ else:
122
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
123
+
124
+
125
+ def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
126
+ """
127
+ Create a beta schedule that discretizes the given alpha_t_bar function,
128
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
129
+ :param num_diffusion_timesteps: the number of betas to produce.
130
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
131
+ produces the cumulative product of (1-beta) up to that
132
+ part of the diffusion process.
133
+ :param max_beta: the maximum beta to use; use values lower than 1 to
134
+ prevent singularities.
135
+ """
136
+ betas = []
137
+ for i in range(num_diffusion_timesteps):
138
+ t1 = i / num_diffusion_timesteps
139
+ t2 = (i + 1) / num_diffusion_timesteps
140
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
141
+ return np.array(betas)
142
+
143
+
144
+ class GaussianDiffusion:
145
+ """
146
+ Utilities for training and sampling diffusion models.
147
+ Originally ported from this codebase:
148
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
149
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
150
+ starting at T and going to 1.
151
+ """
152
+
153
+ def __init__(
154
+ self,
155
+ *,
156
+ betas,
157
+ model_mean_type,
158
+ model_var_type,
159
+ loss_type
160
+ ):
161
+
162
+ self.model_mean_type = model_mean_type
163
+ self.model_var_type = model_var_type
164
+ self.loss_type = loss_type
165
+
166
+ # Use float64 for accuracy.
167
+ betas = np.array(betas, dtype=np.float64)
168
+ self.betas = betas
169
+ assert len(betas.shape) == 1, "betas must be 1-D"
170
+ assert (betas > 0).all() and (betas <= 1).all()
171
+
172
+ self.num_timesteps = int(betas.shape[0])
173
+
174
+ alphas = 1.0 - betas
175
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
176
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
177
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
178
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
179
+
180
+ # calculations for diffusion q(x_t | x_{t-1}) and others
181
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
182
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
183
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
184
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
185
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
186
+
187
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
188
+ self.posterior_variance = (
189
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
190
+ )
191
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
192
+ self.posterior_log_variance_clipped = np.log(
193
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
194
+ ) if len(self.posterior_variance) > 1 else np.array([])
195
+
196
+ self.posterior_mean_coef1 = (
197
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
198
+ )
199
+ self.posterior_mean_coef2 = (
200
+ (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
201
+ )
202
+
203
+ def q_mean_variance(self, x_start, t):
204
+ """
205
+ Get the distribution q(x_t | x_0).
206
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
207
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
208
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
209
+ """
210
+ mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
211
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
212
+ log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
213
+ return mean, variance, log_variance
214
+
215
+ def q_sample(self, x_start, t, noise=None):
216
+ """
217
+ Diffuse the data for a given number of diffusion steps.
218
+ In other words, sample from q(x_t | x_0).
219
+ :param x_start: the initial data batch.
220
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
221
+ :param noise: if specified, the split-out normal noise.
222
+ :return: A noisy version of x_start.
223
+ """
224
+ if noise is None:
225
+ noise = th.randn_like(x_start)
226
+ assert noise.shape == x_start.shape
227
+ return (
228
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
229
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
230
+ )
231
+
232
+ def q_posterior_mean_variance(self, x_start, x_t, t):
233
+ """
234
+ Compute the mean and variance of the diffusion posterior:
235
+ q(x_{t-1} | x_t, x_0)
236
+ """
237
+ assert x_start.shape == x_t.shape
238
+ posterior_mean = (
239
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
240
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
241
+ )
242
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
243
+ posterior_log_variance_clipped = _extract_into_tensor(
244
+ self.posterior_log_variance_clipped, t, x_t.shape
245
+ )
246
+ assert (
247
+ posterior_mean.shape[0]
248
+ == posterior_variance.shape[0]
249
+ == posterior_log_variance_clipped.shape[0]
250
+ == x_start.shape[0]
251
+ )
252
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
253
+
254
+ def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
255
+ """
256
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
257
+ the initial x, x_0.
258
+ :param model: the model, which takes a signal and a batch of timesteps
259
+ as input.
260
+ :param x: the [N x C x ...] tensor at time t.
261
+ :param t: a 1-D Tensor of timesteps.
262
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
263
+ :param denoised_fn: if not None, a function which applies to the
264
+ x_start prediction before it is used to sample. Applies before
265
+ clip_denoised.
266
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
267
+ pass to the model. This can be used for conditioning.
268
+ :return: a dict with the following keys:
269
+ - 'mean': the model mean output.
270
+ - 'variance': the model variance output.
271
+ - 'log_variance': the log of 'variance'.
272
+ - 'pred_xstart': the prediction for x_0.
273
+ """
274
+ if model_kwargs is None:
275
+ model_kwargs = {}
276
+
277
+ B, C = x.shape[:2]
278
+ assert t.shape == (B,)
279
+ model_output = model(x, t, **model_kwargs)
280
+ if isinstance(model_output, tuple):
281
+ model_output, extra = model_output
282
+ else:
283
+ extra = None
284
+
285
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
286
+ assert model_output.shape == (B, C * 2, *x.shape[2:])
287
+ model_output, model_var_values = th.split(model_output, C, dim=1)
288
+ min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
289
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
290
+ # The model_var_values is [-1, 1] for [min_var, max_var].
291
+ frac = (model_var_values + 1) / 2
292
+ model_log_variance = frac * max_log + (1 - frac) * min_log
293
+ model_variance = th.exp(model_log_variance)
294
+ else:
295
+ model_variance, model_log_variance = {
296
+ # for fixedlarge, we set the initial (log-)variance like so
297
+ # to get a better decoder log likelihood.
298
+ ModelVarType.FIXED_LARGE: (
299
+ np.append(self.posterior_variance[1], self.betas[1:]),
300
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
301
+ ),
302
+ ModelVarType.FIXED_SMALL: (
303
+ self.posterior_variance,
304
+ self.posterior_log_variance_clipped,
305
+ ),
306
+ }[self.model_var_type]
307
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
308
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
309
+
310
+ def process_xstart(x):
311
+ if denoised_fn is not None:
312
+ x = denoised_fn(x)
313
+ if clip_denoised:
314
+ return x.clamp(-1, 1)
315
+ return x
316
+
317
+ if self.model_mean_type == ModelMeanType.START_X:
318
+ pred_xstart = process_xstart(model_output)
319
+ else:
320
+ pred_xstart = process_xstart(
321
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
322
+ )
323
+ model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
324
+
325
+ assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
326
+ return {
327
+ "mean": model_mean,
328
+ "variance": model_variance,
329
+ "log_variance": model_log_variance,
330
+ "pred_xstart": pred_xstart,
331
+ "extra": extra,
332
+ }
333
+
334
+ def _predict_xstart_from_eps(self, x_t, t, eps):
335
+ assert x_t.shape == eps.shape
336
+ return (
337
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
338
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
339
+ )
340
+
341
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
342
+ return (
343
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
344
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
345
+
346
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
347
+ """
348
+ Compute the mean for the previous step, given a function cond_fn that
349
+ computes the gradient of a conditional log probability with respect to
350
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
351
+ condition on y.
352
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
353
+ """
354
+ gradient = cond_fn(x, t, **model_kwargs)
355
+ new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
356
+ return new_mean
357
+
358
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
359
+ """
360
+ Compute what the p_mean_variance output would have been, should the
361
+ model's score function be conditioned by cond_fn.
362
+ See condition_mean() for details on cond_fn.
363
+ Unlike condition_mean(), this instead uses the conditioning strategy
364
+ from Song et al (2020).
365
+ """
366
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
367
+
368
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
369
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
370
+
371
+ out = p_mean_var.copy()
372
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
373
+ out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
374
+ return out
375
+
376
+ def p_sample(
377
+ self,
378
+ model,
379
+ x,
380
+ t,
381
+ clip_denoised=True,
382
+ denoised_fn=None,
383
+ cond_fn=None,
384
+ model_kwargs=None,
385
+ temperature=1.0
386
+ ):
387
+ """
388
+ Sample x_{t-1} from the model at the given timestep.
389
+ :param model: the model to sample from.
390
+ :param x: the current tensor at x_{t-1}.
391
+ :param t: the value of t, starting at 0 for the first diffusion step.
392
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
393
+ :param denoised_fn: if not None, a function which applies to the
394
+ x_start prediction before it is used to sample.
395
+ :param cond_fn: if not None, this is a gradient function that acts
396
+ similarly to the model.
397
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
398
+ pass to the model. This can be used for conditioning.
399
+ :param temperature: temperature scaling during Diff Loss sampling.
400
+ :return: a dict containing the following keys:
401
+ - 'sample': a random sample from the model.
402
+ - 'pred_xstart': a prediction of x_0.
403
+ """
404
+ out = self.p_mean_variance(
405
+ model,
406
+ x,
407
+ t,
408
+ clip_denoised=clip_denoised,
409
+ denoised_fn=denoised_fn,
410
+ model_kwargs=model_kwargs,
411
+ )
412
+ noise = th.randn_like(x)
413
+ nonzero_mask = (
414
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
415
+ ) # no noise when t == 0
416
+ if cond_fn is not None:
417
+ out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
418
+ # scale the noise by temperature
419
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise * temperature
420
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
421
+
422
+ def p_sample_loop(
423
+ self,
424
+ model,
425
+ shape,
426
+ noise=None,
427
+ clip_denoised=True,
428
+ denoised_fn=None,
429
+ cond_fn=None,
430
+ model_kwargs=None,
431
+ device=None,
432
+ progress=False,
433
+ temperature=1.0,
434
+ ):
435
+ """
436
+ Generate samples from the model.
437
+ :param model: the model module.
438
+ :param shape: the shape of the samples, (N, C, H, W).
439
+ :param noise: if specified, the noise from the encoder to sample.
440
+ Should be of the same shape as `shape`.
441
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
442
+ :param denoised_fn: if not None, a function which applies to the
443
+ x_start prediction before it is used to sample.
444
+ :param cond_fn: if not None, this is a gradient function that acts
445
+ similarly to the model.
446
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
447
+ pass to the model. This can be used for conditioning.
448
+ :param device: if specified, the device to create the samples on.
449
+ If not specified, use a model parameter's device.
450
+ :param progress: if True, show a tqdm progress bar.
451
+ :param temperature: temperature scaling during Diff Loss sampling.
452
+ :return: a non-differentiable batch of samples.
453
+ """
454
+ final = None
455
+ for sample in self.p_sample_loop_progressive(
456
+ model,
457
+ shape,
458
+ noise=noise,
459
+ clip_denoised=clip_denoised,
460
+ denoised_fn=denoised_fn,
461
+ cond_fn=cond_fn,
462
+ model_kwargs=model_kwargs,
463
+ device=device,
464
+ progress=progress,
465
+ temperature=temperature,
466
+ ):
467
+ final = sample
468
+ return final["sample"]
469
+
470
+ def p_sample_loop_progressive(
471
+ self,
472
+ model,
473
+ shape,
474
+ noise=None,
475
+ clip_denoised=True,
476
+ denoised_fn=None,
477
+ cond_fn=None,
478
+ model_kwargs=None,
479
+ device=None,
480
+ progress=False,
481
+ temperature=1.0,
482
+ ):
483
+ """
484
+ Generate samples from the model and yield intermediate samples from
485
+ each timestep of diffusion.
486
+ Arguments are the same as p_sample_loop().
487
+ Returns a generator over dicts, where each dict is the return value of
488
+ p_sample().
489
+ """
490
+ if device is None:
491
+ device = next(model.parameters()).device
492
+ assert isinstance(shape, (tuple, list))
493
+ if noise is not None:
494
+ img = noise
495
+ else:
496
+ img = th.randn(*shape, device=device)
497
+ indices = list(range(self.num_timesteps))[::-1]
498
+
499
+ if progress:
500
+ # Lazy import so that we don't depend on tqdm.
501
+ from tqdm.auto import tqdm
502
+
503
+ indices = tqdm(indices)
504
+
505
+ for i in indices:
506
+ t = th.tensor([i] * shape[0], device=device)
507
+ with th.no_grad():
508
+ out = self.p_sample(
509
+ model,
510
+ img,
511
+ t,
512
+ clip_denoised=clip_denoised,
513
+ denoised_fn=denoised_fn,
514
+ cond_fn=cond_fn,
515
+ model_kwargs=model_kwargs,
516
+ temperature=temperature,
517
+ )
518
+ yield out
519
+ img = out["sample"]
520
+
521
+ def ddim_sample(
522
+ self,
523
+ model,
524
+ x,
525
+ t,
526
+ clip_denoised=True,
527
+ denoised_fn=None,
528
+ cond_fn=None,
529
+ model_kwargs=None,
530
+ eta=0.0,
531
+ ):
532
+ """
533
+ Sample x_{t-1} from the model using DDIM.
534
+ Same usage as p_sample().
535
+ """
536
+ out = self.p_mean_variance(
537
+ model,
538
+ x,
539
+ t,
540
+ clip_denoised=clip_denoised,
541
+ denoised_fn=denoised_fn,
542
+ model_kwargs=model_kwargs,
543
+ )
544
+ if cond_fn is not None:
545
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
546
+
547
+ # Usually our model outputs epsilon, but we re-derive it
548
+ # in case we used x_start or x_prev prediction.
549
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
550
+
551
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
552
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
553
+ sigma = (
554
+ eta
555
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
556
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
557
+ )
558
+ # Equation 12.
559
+ noise = th.randn_like(x)
560
+ mean_pred = (
561
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
562
+ + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
563
+ )
564
+ nonzero_mask = (
565
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
566
+ ) # no noise when t == 0
567
+ sample = mean_pred + nonzero_mask * sigma * noise
568
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
569
+
570
+ def ddim_reverse_sample(
571
+ self,
572
+ model,
573
+ x,
574
+ t,
575
+ clip_denoised=True,
576
+ denoised_fn=None,
577
+ cond_fn=None,
578
+ model_kwargs=None,
579
+ eta=0.0,
580
+ ):
581
+ """
582
+ Sample x_{t+1} from the model using DDIM reverse ODE.
583
+ """
584
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
585
+ out = self.p_mean_variance(
586
+ model,
587
+ x,
588
+ t,
589
+ clip_denoised=clip_denoised,
590
+ denoised_fn=denoised_fn,
591
+ model_kwargs=model_kwargs,
592
+ )
593
+ if cond_fn is not None:
594
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
595
+ # Usually our model outputs epsilon, but we re-derive it
596
+ # in case we used x_start or x_prev prediction.
597
+ eps = (
598
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
599
+ - out["pred_xstart"]
600
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
601
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
602
+
603
+ # Equation 12. reversed
604
+ mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
605
+
606
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
607
+
608
+ def ddim_sample_loop(
609
+ self,
610
+ model,
611
+ shape,
612
+ noise=None,
613
+ clip_denoised=True,
614
+ denoised_fn=None,
615
+ cond_fn=None,
616
+ model_kwargs=None,
617
+ device=None,
618
+ progress=False,
619
+ eta=0.0,
620
+ ):
621
+ """
622
+ Generate samples from the model using DDIM.
623
+ Same usage as p_sample_loop().
624
+ """
625
+ final = None
626
+ for sample in self.ddim_sample_loop_progressive(
627
+ model,
628
+ shape,
629
+ noise=noise,
630
+ clip_denoised=clip_denoised,
631
+ denoised_fn=denoised_fn,
632
+ cond_fn=cond_fn,
633
+ model_kwargs=model_kwargs,
634
+ device=device,
635
+ progress=progress,
636
+ eta=eta,
637
+ ):
638
+ final = sample
639
+ return final["sample"]
640
+
641
+ def ddim_sample_loop_progressive(
642
+ self,
643
+ model,
644
+ shape,
645
+ noise=None,
646
+ clip_denoised=True,
647
+ denoised_fn=None,
648
+ cond_fn=None,
649
+ model_kwargs=None,
650
+ device=None,
651
+ progress=False,
652
+ eta=0.0,
653
+ ):
654
+ """
655
+ Use DDIM to sample from the model and yield intermediate samples from
656
+ each timestep of DDIM.
657
+ Same usage as p_sample_loop_progressive().
658
+ """
659
+ if device is None:
660
+ device = next(model.parameters()).device
661
+ assert isinstance(shape, (tuple, list))
662
+ if noise is not None:
663
+ img = noise
664
+ else:
665
+ img = th.randn(*shape, device=device)
666
+ indices = list(range(self.num_timesteps))[::-1]
667
+
668
+ if progress:
669
+ # Lazy import so that we don't depend on tqdm.
670
+ from tqdm.auto import tqdm
671
+
672
+ indices = tqdm(indices)
673
+
674
+ for i in indices:
675
+ t = th.tensor([i] * shape[0], device=device)
676
+ with th.no_grad():
677
+ out = self.ddim_sample(
678
+ model,
679
+ img,
680
+ t,
681
+ clip_denoised=clip_denoised,
682
+ denoised_fn=denoised_fn,
683
+ cond_fn=cond_fn,
684
+ model_kwargs=model_kwargs,
685
+ eta=eta,
686
+ )
687
+ yield out
688
+ img = out["sample"]
689
+
690
+ def _vb_terms_bpd(
691
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
692
+ ):
693
+ """
694
+ Get a term for the variational lower-bound.
695
+ The resulting units are bits (rather than nats, as one might expect).
696
+ This allows for comparison to other papers.
697
+ :return: a dict with the following keys:
698
+ - 'output': a shape [N] tensor of NLLs or KLs.
699
+ - 'pred_xstart': the x_0 predictions.
700
+ """
701
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
702
+ x_start=x_start, x_t=x_t, t=t
703
+ )
704
+ out = self.p_mean_variance(
705
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
706
+ )
707
+ kl = normal_kl(
708
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
709
+ )
710
+ kl = mean_flat(kl) / np.log(2.0)
711
+
712
+ decoder_nll = -discretized_gaussian_log_likelihood(
713
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
714
+ )
715
+ assert decoder_nll.shape == x_start.shape
716
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
717
+
718
+ # At the first timestep return the decoder NLL,
719
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
720
+ output = th.where((t == 0), decoder_nll, kl)
721
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
722
+
723
+ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
724
+ """
725
+ Compute training losses for a single timestep.
726
+ :param model: the model to evaluate loss on.
727
+ :param x_start: the [N x C x ...] tensor of inputs.
728
+ :param t: a batch of timestep indices.
729
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
730
+ pass to the model. This can be used for conditioning.
731
+ :param noise: if specified, the specific Gaussian noise to try to remove.
732
+ :return: a dict with the key "loss" containing a tensor of shape [N].
733
+ Some mean or variance settings may also have other keys.
734
+ """
735
+ if model_kwargs is None:
736
+ model_kwargs = {}
737
+ if noise is None:
738
+ noise = th.randn_like(x_start)
739
+ x_t = self.q_sample(x_start, t, noise=noise)
740
+
741
+ terms = {}
742
+
743
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
744
+ terms["loss"] = self._vb_terms_bpd(
745
+ model=model,
746
+ x_start=x_start,
747
+ x_t=x_t,
748
+ t=t,
749
+ clip_denoised=False,
750
+ model_kwargs=model_kwargs,
751
+ )["output"]
752
+ if self.loss_type == LossType.RESCALED_KL:
753
+ terms["loss"] *= self.num_timesteps
754
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
755
+ model_output = model(x_t, t, **model_kwargs)
756
+
757
+ if self.model_var_type in [
758
+ ModelVarType.LEARNED,
759
+ ModelVarType.LEARNED_RANGE,
760
+ ]:
761
+ B, C = x_t.shape[:2]
762
+ if len(model_output.shape) == len(x_t.shape) + 1:
763
+ x_t = x_t.unsqueeze(-1).expand(*([-1] * (len(x_t.shape))), model_output.shape[-1])
764
+ x_start = x_start.unsqueeze(-1).expand(*([-1] * (len(x_start.shape))), model_output.shape[-1])
765
+ assert model_output.shape == (B, C * 2, *x_t.shape[2:])
766
+ model_output, model_var_values = th.split(model_output, C, dim=1)
767
+ # Learn the variance using the variational bound, but don't let
768
+ # it affect our mean prediction.
769
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
770
+ terms["vb"] = self._vb_terms_bpd(
771
+ model=lambda *args, r=frozen_out: r,
772
+ x_start=x_start,
773
+ x_t=x_t,
774
+ t=t,
775
+ clip_denoised=False,
776
+ )["output"]
777
+ if self.loss_type == LossType.RESCALED_MSE:
778
+ # Divide by 1000 for equivalence with initial implementation.
779
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
780
+ terms["vb"] *= self.num_timesteps / 1000.0
781
+
782
+ target = {
783
+ ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
784
+ x_start=x_start, x_t=x_t, t=t
785
+ )[0],
786
+ ModelMeanType.START_X: x_start,
787
+ ModelMeanType.EPSILON: noise,
788
+ }[self.model_mean_type]
789
+ if len(model_output.shape) == len(target.shape) + 1:
790
+ target = target.unsqueeze(-1).expand(*([-1] * (len(target.shape))), model_output.shape[-1])
791
+ assert model_output.shape == target.shape == x_start.shape
792
+ terms["mse"] = mean_flat((target - model_output) ** 2)
793
+ if "vb" in terms:
794
+ terms["loss"] = terms["mse"] + terms["vb"]
795
+ else:
796
+ terms["loss"] = terms["mse"]
797
+ else:
798
+ raise NotImplementedError(self.loss_type)
799
+
800
+ return terms
801
+
802
+ def _prior_bpd(self, x_start):
803
+ """
804
+ Get the prior KL term for the variational lower-bound, measured in
805
+ bits-per-dim.
806
+ This term can't be optimized, as it only depends on the encoder.
807
+ :param x_start: the [N x C x ...] tensor of inputs.
808
+ :return: a batch of [N] KL values (in bits), one per batch element.
809
+ """
810
+ batch_size = x_start.shape[0]
811
+ t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
812
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
813
+ kl_prior = normal_kl(
814
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
815
+ )
816
+ return mean_flat(kl_prior) / np.log(2.0)
817
+
818
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
819
+ """
820
+ Compute the entire variational lower-bound, measured in bits-per-dim,
821
+ as well as other related quantities.
822
+ :param model: the model to evaluate loss on.
823
+ :param x_start: the [N x C x ...] tensor of inputs.
824
+ :param clip_denoised: if True, clip denoised samples.
825
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
826
+ pass to the model. This can be used for conditioning.
827
+ :return: a dict containing the following keys:
828
+ - total_bpd: the total variational lower-bound, per batch element.
829
+ - prior_bpd: the prior term in the lower-bound.
830
+ - vb: an [N x T] tensor of terms in the lower-bound.
831
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
832
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
833
+ """
834
+ device = x_start.device
835
+ batch_size = x_start.shape[0]
836
+
837
+ vb = []
838
+ xstart_mse = []
839
+ mse = []
840
+ for t in list(range(self.num_timesteps))[::-1]:
841
+ t_batch = th.tensor([t] * batch_size, device=device)
842
+ noise = th.randn_like(x_start)
843
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
844
+ # Calculate VLB term at the current timestep
845
+ with th.no_grad():
846
+ out = self._vb_terms_bpd(
847
+ model,
848
+ x_start=x_start,
849
+ x_t=x_t,
850
+ t=t_batch,
851
+ clip_denoised=clip_denoised,
852
+ model_kwargs=model_kwargs,
853
+ )
854
+ vb.append(out["output"])
855
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
856
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
857
+ mse.append(mean_flat((eps - noise) ** 2))
858
+
859
+ vb = th.stack(vb, dim=1)
860
+ xstart_mse = th.stack(xstart_mse, dim=1)
861
+ mse = th.stack(mse, dim=1)
862
+
863
+ prior_bpd = self._prior_bpd(x_start)
864
+ total_bpd = vb.sum(dim=1) + prior_bpd
865
+ return {
866
+ "total_bpd": total_bpd,
867
+ "prior_bpd": prior_bpd,
868
+ "vb": vb,
869
+ "xstart_mse": xstart_mse,
870
+ "mse": mse,
871
+ }
872
+
873
+
874
+ def _extract_into_tensor(arr, timesteps, broadcast_shape):
875
+ """
876
+ Extract values from a 1-D numpy array for a batch of indices.
877
+ :param arr: the 1-D numpy array.
878
+ :param timesteps: a tensor of indices into the array to extract.
879
+ :param broadcast_shape: a larger shape of K dimensions with the batch
880
+ dimension equal to the length of timesteps.
881
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
882
+ """
883
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
884
+ while len(res.shape) < len(broadcast_shape):
885
+ res = res[..., None]
886
+ return res + th.zeros(broadcast_shape, device=timesteps.device)
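
For orientation, a minimal sketch of how this class is typically driven (illustrative only: `diffusion` stands in for a GaussianDiffusion/SpacedDiffusion instance built from this module, and `model` for a network mapping (x_t, t, **kwargs) to a prediction; the batch shape and class count are assumptions):

import torch as th

# Ancestral sampling with the loop defined above; `diffusion` and `model`
# are hypothetical stand-ins, not objects constructed in this file.
samples = diffusion.p_sample_loop(
    model,
    shape=(8, 4, 32, 32),                            # (N, C, H, W)
    clip_denoised=True,
    model_kwargs={"y": th.randint(0, 1000, (8,))},   # class conditioning
    progress=True,
)
# DDIM gives a deterministic trajectory when eta=0.0:
ddim_samples = diffusion.ddim_sample_loop(model, (8, 4, 32, 32), eta=0.0)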
semanticist/stage1/diffusion/respace.py ADDED
@@ -0,0 +1,130 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ import numpy as np
7
+ import torch as th
8
+
9
+ from .gaussian_diffusion import GaussianDiffusion
10
+
11
+
12
+ def space_timesteps(num_timesteps, section_counts):
13
+ """
14
+ Create a list of timesteps to use from an original diffusion process,
15
+ given the number of timesteps we want to take from equally-sized portions
16
+ of the original process.
17
+ For example, if there are 300 timesteps and the section counts are [10,15,20],
18
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
19
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
20
+ If the stride is a string starting with "ddim", then the fixed striding
21
+ from the DDIM paper is used, and only one section is allowed.
22
+ :param num_timesteps: the number of diffusion steps in the original
23
+ process to divide up.
24
+ :param section_counts: either a list of numbers, or a string containing
25
+ comma-separated numbers, indicating the step count
26
+ per section. As a special case, use "ddimN" where N
27
+ is a number of steps to use the striding from the
28
+ DDIM paper.
29
+ :return: a set of diffusion steps from the original process to use.
30
+ """
31
+ if isinstance(section_counts, str):
32
+ if section_counts.startswith("ddim"):
33
+ desired_count = int(section_counts[len("ddim") :])
34
+ for i in range(1, num_timesteps):
35
+ if len(range(0, num_timesteps, i)) == desired_count:
36
+ return set(range(0, num_timesteps, i))
37
+ raise ValueError(
38
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
39
+ )
40
+ section_counts = [int(x) for x in section_counts.split(",")]
41
+ size_per = num_timesteps // len(section_counts)
42
+ extra = num_timesteps % len(section_counts)
43
+ start_idx = 0
44
+ all_steps = []
45
+ for i, section_count in enumerate(section_counts):
46
+ size = size_per + (1 if i < extra else 0)
47
+ if size < section_count:
48
+ raise ValueError(
49
+ f"cannot divide section of {size} steps into {section_count}"
50
+ )
51
+ if section_count <= 1:
52
+ frac_stride = 1
53
+ else:
54
+ frac_stride = (size - 1) / (section_count - 1)
55
+ cur_idx = 0.0
56
+ taken_steps = []
57
+ for _ in range(section_count):
58
+ taken_steps.append(start_idx + round(cur_idx))
59
+ cur_idx += frac_stride
60
+ all_steps += taken_steps
61
+ start_idx += size
62
+ return set(all_steps)
63
+
64
+
65
+ class SpacedDiffusion(GaussianDiffusion):
66
+ """
67
+ A diffusion process which can skip steps in a base diffusion process.
68
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
69
+ original diffusion process to retain.
70
+ :param kwargs: the kwargs to create the base diffusion process.
71
+ """
72
+
73
+ def __init__(self, use_timesteps, **kwargs):
74
+ self.use_timesteps = set(use_timesteps)
75
+ self.timestep_map = []
76
+ self.original_num_steps = len(kwargs["betas"])
77
+
78
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
79
+ last_alpha_cumprod = 1.0
80
+ new_betas = []
81
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
82
+ if i in self.use_timesteps:
83
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
84
+ last_alpha_cumprod = alpha_cumprod
85
+ self.timestep_map.append(i)
86
+ kwargs["betas"] = np.array(new_betas)
87
+ super().__init__(**kwargs)
88
+
89
+ def p_mean_variance(
90
+ self, model, *args, **kwargs
91
+ ): # pylint: disable=signature-differs
92
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
93
+
94
+ def training_losses(
95
+ self, model, *args, **kwargs
96
+ ): # pylint: disable=signature-differs
97
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
98
+
99
+ def condition_mean(self, cond_fn, *args, **kwargs):
100
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
101
+
102
+ def condition_score(self, cond_fn, *args, **kwargs):
103
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
104
+
105
+ def _wrap_model(self, model):
106
+ if isinstance(model, _WrappedModel):
107
+ return model
108
+ return _WrappedModel(
109
+ model, self.timestep_map, self.original_num_steps
110
+ )
111
+
112
+ def _scale_timesteps(self, t):
113
+ # Scaling is done by the wrapped model.
114
+ return t
115
+
116
+
117
+ class _WrappedModel:
118
+ def __init__(self, model, timestep_map, original_num_steps):
119
+ self.model = model
120
+ self.timestep_map = timestep_map
121
+ # self.rescale_timesteps = rescale_timesteps
122
+ self.original_num_steps = original_num_steps
123
+
124
+ def __call__(self, x, ts, **kwargs):
125
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
126
+ new_ts = map_tensor[ts]
127
+ # if self.rescale_timesteps:
128
+ # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
129
+ return self.model(x, new_ts, **kwargs)
130
+
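
A quick illustration of `space_timesteps` (not part of the diff; the step counts are arbitrary examples):

# Respace a 1000-step process down to an evenly strided 250-step subset.
steps = space_timesteps(1000, "250")
assert len(steps) == 250
# "ddimN" uses the fixed integer stride from the DDIM paper:
ddim_steps = space_timesteps(1000, "ddim50")   # stride 20 -> 50 steps
assert len(ddim_steps) == 50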
semanticist/stage1/diffusion/timestep_sampler.py ADDED
@@ -0,0 +1,150 @@
1
+ # Modified from OpenAI's diffusion repos
2
+ # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
+ # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
+ # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
+
6
+ from abc import ABC, abstractmethod
7
+
8
+ import numpy as np
9
+ import torch as th
10
+ import torch.distributed as dist
11
+
12
+
13
+ def create_named_schedule_sampler(name, diffusion):
14
+ """
15
+ Create a ScheduleSampler from a library of pre-defined samplers.
16
+ :param name: the name of the sampler.
17
+ :param diffusion: the diffusion object to sample for.
18
+ """
19
+ if name == "uniform":
20
+ return UniformSampler(diffusion)
21
+ elif name == "loss-second-moment":
22
+ return LossSecondMomentResampler(diffusion)
23
+ else:
24
+ raise NotImplementedError(f"unknown schedule sampler: {name}")
25
+
26
+
27
+ class ScheduleSampler(ABC):
28
+ """
29
+ A distribution over timesteps in the diffusion process, intended to reduce
30
+ variance of the objective.
31
+ By default, samplers perform unbiased importance sampling, in which the
32
+ objective's mean is unchanged.
33
+ However, subclasses may override sample() to change how the resampled
34
+ terms are reweighted, allowing for actual changes in the objective.
35
+ """
36
+
37
+ @abstractmethod
38
+ def weights(self):
39
+ """
40
+ Get a numpy array of weights, one per diffusion step.
41
+ The weights needn't be normalized, but must be positive.
42
+ """
43
+
44
+ def sample(self, batch_size, device):
45
+ """
46
+ Importance-sample timesteps for a batch.
47
+ :param batch_size: the number of timesteps.
48
+ :param device: the torch device to save to.
49
+ :return: a tuple (timesteps, weights):
50
+ - timesteps: a tensor of timestep indices.
51
+ - weights: a tensor of weights to scale the resulting losses.
52
+ """
53
+ w = self.weights()
54
+ p = w / np.sum(w)
55
+ indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
56
+ indices = th.from_numpy(indices_np).long().to(device)
57
+ weights_np = 1 / (len(p) * p[indices_np])
58
+ weights = th.from_numpy(weights_np).float().to(device)
59
+ return indices, weights
60
+
61
+
62
+ class UniformSampler(ScheduleSampler):
63
+ def __init__(self, diffusion):
64
+ self.diffusion = diffusion
65
+ self._weights = np.ones([diffusion.num_timesteps])
66
+
67
+ def weights(self):
68
+ return self._weights
69
+
70
+
71
+ class LossAwareSampler(ScheduleSampler):
72
+ def update_with_local_losses(self, local_ts, local_losses):
73
+ """
74
+ Update the reweighting using losses from a model.
75
+ Call this method from each rank with a batch of timesteps and the
76
+ corresponding losses for each of those timesteps.
77
+ This method will perform synchronization to make sure all of the ranks
78
+ maintain the exact same reweighting.
79
+ :param local_ts: an integer Tensor of timesteps.
80
+ :param local_losses: a 1D Tensor of losses.
81
+ """
82
+ batch_sizes = [
83
+ th.tensor([0], dtype=th.int32, device=local_ts.device)
84
+ for _ in range(dist.get_world_size())
85
+ ]
86
+ dist.all_gather(
87
+ batch_sizes,
88
+ th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
89
+ )
90
+
91
+ # Pad all_gather batches to be the maximum batch size.
92
+ batch_sizes = [x.item() for x in batch_sizes]
93
+ max_bs = max(batch_sizes)
94
+
95
+ timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
96
+ loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
97
+ dist.all_gather(timestep_batches, local_ts)
98
+ dist.all_gather(loss_batches, local_losses)
99
+ timesteps = [
100
+ x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
101
+ ]
102
+ losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
103
+ self.update_with_all_losses(timesteps, losses)
104
+
105
+ @abstractmethod
106
+ def update_with_all_losses(self, ts, losses):
107
+ """
108
+ Update the reweighting using losses from a model.
109
+ Sub-classes should override this method to update the reweighting
110
+ using losses from the model.
111
+ This method directly updates the reweighting without synchronizing
112
+ between workers. It is called by update_with_local_losses from all
113
+ ranks with identical arguments. Thus, it should have deterministic
114
+ behavior to maintain state across workers.
115
+ :param ts: a list of int timesteps.
116
+ :param losses: a list of float losses, one per timestep.
117
+ """
118
+
119
+
120
+ class LossSecondMomentResampler(LossAwareSampler):
121
+ def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
122
+ self.diffusion = diffusion
123
+ self.history_per_term = history_per_term
124
+ self.uniform_prob = uniform_prob
125
+ self._loss_history = np.zeros(
126
+ [diffusion.num_timesteps, history_per_term], dtype=np.float64
127
+ )
128
+ self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)  # np.int was removed in NumPy 1.24; use a concrete dtype
129
+
130
+ def weights(self):
131
+ if not self._warmed_up():
132
+ return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
133
+ weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
134
+ weights /= np.sum(weights)
135
+ weights *= 1 - self.uniform_prob
136
+ weights += self.uniform_prob / len(weights)
137
+ return weights
138
+
139
+ def update_with_all_losses(self, ts, losses):
140
+ for t, loss in zip(ts, losses):
141
+ if self._loss_counts[t] == self.history_per_term:
142
+ # Shift out the oldest loss term.
143
+ self._loss_history[t, :-1] = self._loss_history[t, 1:]
144
+ self._loss_history[t, -1] = loss
145
+ else:
146
+ self._loss_history[t, self._loss_counts[t]] = loss
147
+ self._loss_counts[t] += 1
148
+
149
+ def _warmed_up(self):
150
+ return (self._loss_counts == self.history_per_term).all()
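
A brief usage sketch (illustrative; `_Stub` is a hypothetical stand-in exposing the one attribute the samplers read):

# Draw importance-sampled timesteps the way a training loop would.
class _Stub:
    num_timesteps = 1000   # the only attribute UniformSampler needs

sampler = create_named_schedule_sampler("uniform", _Stub())
t, weights = sampler.sample(batch_size=32, device="cpu")
assert t.shape == weights.shape == (32,)
# In a real loop: losses = diffusion.training_losses(model, x_start, t)
#                 loss = (losses["loss"] * weights).mean()  # weights keep the objective unbiased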
semanticist/stage1/diffusion_transfomer.py ADDED
@@ -0,0 +1,372 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # GLIDE: https://github.com/openai/glide-text2im
9
+ # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
10
+ # --------------------------------------------------------
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import numpy as np
15
+ import math
16
+ from timm.models.vision_transformer import PatchEmbed, Mlp
17
+ from semanticist.stage1.fused_attention import Attention
18
+
19
+
20
+
21
+ def modulate(x, shift, scale):
22
+ return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
23
+
24
+
25
+ #################################################################################
26
+ # Embedding Layers for Timesteps and Class Labels #
27
+ #################################################################################
28
+
29
+ class TimestepEmbedder(nn.Module):
30
+ """
31
+ Embeds scalar timesteps into vector representations.
32
+ """
33
+ def __init__(self, hidden_size, frequency_embedding_size=256):
34
+ super().__init__()
35
+ self.mlp = nn.Sequential(
36
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
37
+ nn.SiLU(),
38
+ nn.Linear(hidden_size, hidden_size, bias=True),
39
+ )
40
+ self.frequency_embedding_size = frequency_embedding_size
41
+
42
+ @staticmethod
43
+ def timestep_embedding(t, dim, max_period=10000):
44
+ """
45
+ Create sinusoidal timestep embeddings.
46
+ :param t: a 1-D Tensor of N indices, one per batch element.
47
+ These may be fractional.
48
+ :param dim: the dimension of the output.
49
+ :param max_period: controls the minimum frequency of the embeddings.
50
+ :return: an (N, D) Tensor of positional embeddings.
51
+ """
52
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
53
+ half = dim // 2
54
+ freqs = torch.exp(
55
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
56
+ ).to(device=t.device)
57
+ args = t[:, None].float() * freqs[None]
58
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
59
+ if dim % 2:
60
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
61
+ return embedding
62
+
63
+ def forward(self, t):
64
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
65
+ t_emb = self.mlp(t_freq.to(self.mlp[0].weight.dtype))
66
+ return t_emb
67
+
68
+
69
+ class LabelEmbedder(nn.Module):
70
+ """
71
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
72
+ """
73
+ def __init__(self, num_classes, hidden_size, dropout_prob):
74
+ super().__init__()
75
+ use_cfg_embedding = dropout_prob > 0
76
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
77
+ self.num_classes = num_classes
78
+ self.dropout_prob = dropout_prob
79
+
80
+ def token_drop(self, labels, force_drop_ids=None):
81
+ """
82
+ Drops labels to enable classifier-free guidance.
83
+ """
84
+ if force_drop_ids is None:
85
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
86
+ else:
87
+ drop_ids = force_drop_ids == 1
88
+ labels = torch.where(drop_ids, self.num_classes, labels)
89
+ return labels
90
+
91
+ def forward(self, labels, train, force_drop_ids=None):
92
+ use_dropout = self.dropout_prob > 0
93
+ if (train and use_dropout) or (force_drop_ids is not None):
94
+ labels = self.token_drop(labels, force_drop_ids)
95
+ embeddings = self.embedding_table(labels)
96
+ return embeddings
97
+
98
+
99
+ #################################################################################
100
+ # Core DiT Model #
101
+ #################################################################################
102
+
103
+ class DiTBlock(nn.Module):
104
+ """
105
+ A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
106
+ """
107
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
108
+ super().__init__()
109
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
110
+ self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, **block_kwargs)
111
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
112
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
113
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
114
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0)
115
+ self.adaLN_modulation = nn.Sequential(
116
+ nn.SiLU(),
117
+ nn.Linear(hidden_size, 6 * hidden_size, bias=True)
118
+ )
119
+
120
+ def forward(self, x, c, mask=None):
121
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(c).chunk(6, dim=1)
122
+ x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), mask)
123
+ x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp))
124
+ return x
125
+
126
+
127
+ class FinalLayer(nn.Module):
128
+ """
129
+ The final layer of DiT.
130
+ """
131
+ def __init__(self, hidden_size, patch_size, out_channels):
132
+ super().__init__()
133
+ self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
134
+ self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
135
+ self.adaLN_modulation = nn.Sequential(
136
+ nn.SiLU(),
137
+ nn.Linear(hidden_size, 2 * hidden_size, bias=True)
138
+ )
139
+
140
+ def forward(self, x, c):
141
+ shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
142
+ x = modulate(self.norm_final(x), shift, scale)
143
+ x = self.linear(x)
144
+ return x
145
+
146
+
147
+ class DiT(nn.Module):
148
+ """
149
+ Diffusion model with a Transformer backbone.
150
+ """
151
+ def __init__(
152
+ self,
153
+ input_size=32,
154
+ patch_size=2,
155
+ in_channels=4,
156
+ hidden_size=1152,
157
+ depth=28,
158
+ num_heads=16,
159
+ mlp_ratio=4.0,
160
+ class_dropout_prob=0.1,
161
+ num_classes=1000,
162
+ learn_sigma=True,
163
+ ):
164
+ super().__init__()
165
+ self.learn_sigma = learn_sigma
166
+ self.in_channels = in_channels
167
+ self.out_channels = in_channels * 2 if learn_sigma else in_channels
168
+ self.patch_size = patch_size
169
+ self.num_heads = num_heads
170
+
171
+ self.x_embedder = PatchEmbed(input_size, patch_size, in_channels, hidden_size, bias=True)
172
+ self.t_embedder = TimestepEmbedder(hidden_size)
173
+ self.y_embedder = LabelEmbedder(num_classes, hidden_size, class_dropout_prob)
174
+ num_patches = self.x_embedder.num_patches
175
+ # Will use fixed sin-cos embedding:
176
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, hidden_size), requires_grad=False)
177
+
178
+ self.blocks = nn.ModuleList([
179
+ DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)
180
+ ])
181
+ self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels)
182
+ self.initialize_weights()
183
+
184
+ def initialize_weights(self):
185
+ # Initialize transformer layers:
186
+ def _basic_init(module):
187
+ if isinstance(module, nn.Linear):
188
+ torch.nn.init.xavier_uniform_(module.weight)
189
+ if module.bias is not None:
190
+ nn.init.constant_(module.bias, 0)
191
+ self.apply(_basic_init)
192
+
193
+ # Initialize (and freeze) pos_embed by sin-cos embedding:
194
+ pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.x_embedder.num_patches ** 0.5))
195
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
196
+
197
+ # Initialize patch_embed like nn.Linear (instead of nn.Conv2d):
198
+ w = self.x_embedder.proj.weight.data
199
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
200
+ nn.init.constant_(self.x_embedder.proj.bias, 0)
201
+
202
+ # Initialize label embedding table:
203
+ nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02)
204
+
205
+ # Initialize timestep embedding MLP:
206
+ nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02)
207
+ nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02)
208
+
209
+ # Zero-out adaLN modulation layers in DiT blocks:
210
+ for block in self.blocks:
211
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
212
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
213
+
214
+ # Zero-out output layers:
215
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
216
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
217
+ nn.init.constant_(self.final_layer.linear.weight, 0)
218
+ nn.init.constant_(self.final_layer.linear.bias, 0)
219
+
220
+ def unpatchify(self, x):
221
+ """
222
+ x: (N, T, patch_size**2 * C)
223
+ imgs: (N, C, H, W)
224
+ """
225
+ c = self.out_channels
226
+ p = self.x_embedder.patch_size[0]
227
+ h = w = int(x.shape[1] ** 0.5)
228
+ assert h * w == x.shape[1]
229
+
230
+ x = x.reshape(shape=(x.shape[0], h, w, p, p, c))
231
+ x = torch.einsum('nhwpqc->nchpwq', x)
232
+ imgs = x.reshape(shape=(x.shape[0], c, h * p, h * p))
233
+ return imgs
234
+
235
+ def forward(self, x, t, y):
236
+ """
237
+ Forward pass of DiT.
238
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images)
239
+ t: (N,) tensor of diffusion timesteps
240
+ y: (N,) tensor of class labels
241
+ """
242
+ x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T = H * W / patch_size ** 2
243
+ t = self.t_embedder(t) # (N, D)
244
+ y = self.y_embedder(y, self.training) # (N, D)
245
+ c = t + y # (N, D)
246
+ for block in self.blocks:
247
+ x = block(x, c) # (N, T, D)
248
+ x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels)
249
+ x = self.unpatchify(x) # (N, out_channels, H, W)
250
+ return x
251
+
252
+ def forward_with_cfg(self, x, t, y, cfg_scale):
253
+ """
254
+ Forward pass of DiT, but also batches the unconditional forward pass for classifier-free guidance.
255
+ """
256
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
257
+ half = x[: len(x) // 2]
258
+ combined = torch.cat([half, half], dim=0)
259
+ model_out = self.forward(combined, t, y)
260
+ # The original DiT applies classifier-free guidance to only the first three channels
261
+ # for exact reproducibility; here guidance is applied to all model channels instead.
262
+ # To recover the original behavior, swap the line below for the commented-out line after it.
263
+ eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
264
+ # eps, rest = model_out[:, :3], model_out[:, 3:]
265
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
266
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
267
+ eps = torch.cat([half_eps, half_eps], dim=0)
268
+ return torch.cat([eps, rest], dim=1)
269
+
270
+
271
+ #################################################################################
272
+ # Sine/Cosine Positional Embedding Functions #
273
+ #################################################################################
274
+ # https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py
275
+
276
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0):
277
+ """
278
+ grid_size: int of the grid height and width
279
+ return:
280
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
281
+ """
282
+ grid_h = np.arange(grid_size, dtype=np.float32)
283
+ grid_w = np.arange(grid_size, dtype=np.float32)
284
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
285
+ grid = np.stack(grid, axis=0)
286
+
287
+ grid = grid.reshape([2, 1, grid_size, grid_size])
288
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
289
+ if cls_token and extra_tokens > 0:
290
+ pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
291
+ return pos_embed
292
+
293
+
294
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
295
+ assert embed_dim % 2 == 0
296
+
297
+ # use half of dimensions to encode grid_h
298
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
299
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
300
+
301
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
302
+ return emb
303
+
304
+
305
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
306
+ """
307
+ embed_dim: output dimension for each position
308
+ pos: a list of positions to be encoded: size (M,)
309
+ out: (M, D)
310
+ """
311
+ assert embed_dim % 2 == 0
312
+ omega = np.arange(embed_dim // 2, dtype=np.float64)
313
+ omega /= embed_dim / 2.
314
+ omega = 1. / 10000**omega # (D/2,)
315
+
316
+ pos = pos.reshape(-1) # (M,)
317
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
318
+
319
+ emb_sin = np.sin(out) # (M, D/2)
320
+ emb_cos = np.cos(out) # (M, D/2)
321
+
322
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
323
+ return emb
324
+
325
+
326
+ #################################################################################
327
+ # DiT Configs #
328
+ #################################################################################
329
+
330
+ def DiT_XL_2(**kwargs):
331
+ return DiT(depth=28, hidden_size=1152, patch_size=2, num_heads=16, **kwargs)
332
+
333
+ def DiT_XL_4(**kwargs):
334
+ return DiT(depth=28, hidden_size=1152, patch_size=4, num_heads=16, **kwargs)
335
+
336
+ def DiT_XL_8(**kwargs):
337
+ return DiT(depth=28, hidden_size=1152, patch_size=8, num_heads=16, **kwargs)
338
+
339
+ def DiT_L_2(**kwargs):
340
+ return DiT(depth=24, hidden_size=1024, patch_size=2, num_heads=16, **kwargs)
341
+
342
+ def DiT_L_4(**kwargs):
343
+ return DiT(depth=24, hidden_size=1024, patch_size=4, num_heads=16, **kwargs)
344
+
345
+ def DiT_L_8(**kwargs):
346
+ return DiT(depth=24, hidden_size=1024, patch_size=8, num_heads=16, **kwargs)
347
+
348
+ def DiT_B_2(**kwargs):
349
+ return DiT(depth=12, hidden_size=768, patch_size=2, num_heads=12, **kwargs)
350
+
351
+ def DiT_B_4(**kwargs):
352
+ return DiT(depth=12, hidden_size=768, patch_size=4, num_heads=12, **kwargs)
353
+
354
+ def DiT_B_8(**kwargs):
355
+ return DiT(depth=12, hidden_size=768, patch_size=8, num_heads=12, **kwargs)
356
+
357
+ def DiT_S_2(**kwargs):
358
+ return DiT(depth=12, hidden_size=384, patch_size=2, num_heads=6, **kwargs)
359
+
360
+ def DiT_S_4(**kwargs):
361
+ return DiT(depth=12, hidden_size=384, patch_size=4, num_heads=6, **kwargs)
362
+
363
+ def DiT_S_8(**kwargs):
364
+ return DiT(depth=12, hidden_size=384, patch_size=8, num_heads=6, **kwargs)
365
+
366
+
367
+ DiT_models = {
368
+ 'DiT-XL/2': DiT_XL_2, 'DiT-XL/4': DiT_XL_4, 'DiT-XL/8': DiT_XL_8,
369
+ 'DiT-L/2': DiT_L_2, 'DiT-L/4': DiT_L_4, 'DiT-L/8': DiT_L_8,
370
+ 'DiT-B/2': DiT_B_2, 'DiT-B/4': DiT_B_4, 'DiT-B/8': DiT_B_8,
371
+ 'DiT-S/2': DiT_S_2, 'DiT-S/4': DiT_S_4, 'DiT-S/8': DiT_S_8,
372
+ }
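
As a smoke test (illustrative; the latent size and channel count here are assumptions, not values taken from elsewhere in this commit):

import torch

model = DiT_models['DiT-L/2'](input_size=16, in_channels=16, num_classes=1000)
x = torch.randn(2, 16, 16, 16)      # (N, C, H, W) latents
t = torch.randint(0, 1000, (2,))    # diffusion timesteps
y = torch.randint(0, 1000, (2,))    # class labels
out = model(x, t, y)                # channels double because learn_sigma=True
assert out.shape == (2, 32, 16, 16)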
semanticist/stage1/fused_attention.py ADDED
@@ -0,0 +1,45 @@
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+ from typing import Type
4
+
5
+ class Attention(nn.Module):
6
+ def __init__(
7
+ self,
8
+ dim: int,
9
+ num_heads: int = 8,
10
+ qkv_bias: bool = False,
11
+ qk_norm: bool = False,
12
+ proj_bias: bool = True,
13
+ attn_drop: float = 0.,
14
+ proj_drop: float = 0.,
15
+ norm_layer: Type[nn.Module] = nn.LayerNorm,
16
+ ) -> None:
17
+ super().__init__()
18
+ assert dim % num_heads == 0, 'dim should be divisible by num_heads'
19
+ self.num_heads = num_heads
20
+ self.head_dim = dim // num_heads
21
+ self.scale = self.head_dim ** -0.5
22
+
23
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
24
+ self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
25
+ self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
26
+ self.attn_drop = nn.Dropout(attn_drop)
27
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
28
+ self.proj_drop = nn.Dropout(proj_drop)
29
+
30
+ def forward(self, x, attn_mask=None):
31
+ B, N, C = x.shape
32
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
33
+ q, k, v = qkv.unbind(0)
34
+ q, k = self.q_norm(q), self.k_norm(k)
35
+
36
+ x = F.scaled_dot_product_attention(
37
+ q, k, v,
38
+ attn_mask=attn_mask,
39
+ dropout_p=self.attn_drop.p if self.training else 0.,
40
+ )
41
+
42
+ x = x.transpose(1, 2).reshape(B, N, C)
43
+ x = self.proj(x)
44
+ x = self.proj_drop(x)
45
+ return x
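
A small usage sketch (shapes are assumptions); `attn_mask` follows `F.scaled_dot_product_attention` semantics, so a boolean mask marks positions that may attend:

import torch

attn = Attention(dim=768, num_heads=12, qkv_bias=True)
x = torch.randn(2, 64, 768)                                # (B, N, C)
causal = torch.tril(torch.ones(64, 64, dtype=torch.bool))  # True = attend
assert attn(x, attn_mask=causal).shape == x.shape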
semanticist/stage1/pos_embed.py ADDED
@@ -0,0 +1,102 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # Position embedding utils
8
+ # --------------------------------------------------------
9
+
10
+ import numpy as np
11
+
12
+ import torch
13
+
14
+ # --------------------------------------------------------
15
+ # 2D sine-cosine position embedding
16
+ # References:
17
+ # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
18
+ # MoCo v3: https://github.com/facebookresearch/moco-v3
19
+ # --------------------------------------------------------
20
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
21
+ """
22
+ grid_size: int of the grid height and width
23
+ return:
24
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
25
+ """
26
+ grid_h = np.arange(grid_size, dtype=np.float32)
27
+ grid_w = np.arange(grid_size, dtype=np.float32)
28
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
29
+ grid = np.stack(grid, axis=0)
30
+
31
+ grid = grid.reshape([2, 1, grid_size, grid_size])
32
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
33
+ if cls_token:
34
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
35
+ return pos_embed
36
+
37
+
38
+ def get_1d_sincos_pos_embed(embed_dim, grid_size):
39
+ grid = np.arange(grid_size, dtype=np.float32)
40
+ pos_embed = get_1d_sincos_pos_embed_from_grid(embed_dim, grid)
41
+ return pos_embed
42
+
43
+
44
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
45
+ assert embed_dim % 2 == 0
46
+
47
+ # use half of dimensions to encode grid_h
48
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
49
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
50
+
51
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
52
+ return emb
53
+
54
+
55
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
56
+ """
57
+ embed_dim: output dimension for each position
58
+ pos: a list of positions to be encoded: size (M,)
59
+ out: (M, D)
60
+ """
61
+ assert embed_dim % 2 == 0
62
+ omega = np.arange(embed_dim // 2, dtype=float)
63
+ omega /= embed_dim / 2.
64
+ omega = 1. / 10000**omega # (D/2,)
65
+
66
+ pos = pos.reshape(-1) # (M,)
67
+ out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
68
+
69
+ emb_sin = np.sin(out) # (M, D/2)
70
+ emb_cos = np.cos(out) # (M, D/2)
71
+
72
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
73
+ return emb
74
+
75
+
76
+ # --------------------------------------------------------
77
+ # Interpolate position embeddings for high-resolution
78
+ # References:
79
+ # DeiT: https://github.com/facebookresearch/deit
80
+ # --------------------------------------------------------
81
+ def interpolate_pos_embed(model, checkpoint_model):
82
+ if 'pos_embed' in checkpoint_model:
83
+ pos_embed_checkpoint = checkpoint_model['pos_embed']
84
+ embedding_size = pos_embed_checkpoint.shape[-1]
85
+ num_patches = model.patch_embed.num_patches
86
+ num_extra_tokens = model.pos_embed.shape[-2] - num_patches
87
+ # height (== width) for the checkpoint position embedding
88
+ orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
89
+ # height (== width) for the new position embedding
90
+ new_size = int(num_patches ** 0.5)
91
+ # class_token and dist_token are kept unchanged
92
+ if orig_size != new_size:
93
+ print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
94
+ extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
95
+ # only the position tokens are interpolated
96
+ pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
97
+ pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
98
+ pos_tokens = torch.nn.functional.interpolate(
99
+ pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
100
+ pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
101
+ new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
102
+ checkpoint_model['pos_embed'] = new_pos_embed
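
For reference, a minimal example of the embedding helpers above (the dimensions are assumed):

pos = get_2d_sincos_pos_embed(embed_dim=768, grid_size=16)    # (256, 768)
pos_cls = get_2d_sincos_pos_embed(768, 16, cls_token=True)    # zeros row prepended for [CLS]
assert pos.shape == (256, 768) and pos_cls.shape == (257, 768)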
semanticist/stage1/transport/__init__.py ADDED
@@ -0,0 +1,63 @@
1
+ from .transport import Transport, ModelType, WeightType, PathType, Sampler
2
+
3
+ def create_transport(
4
+ path_type='Linear',
5
+ prediction="velocity",
6
+ loss_weight=None,
7
+ train_eps=None,
8
+ sample_eps=None,
9
+ ):
10
+ """function for creating Transport object
11
+ **Note**: model prediction defaults to velocity
12
+ Args:
13
+ - path_type: type of path to use; default to linear
14
+ - prediction: model prediction type; one of "velocity" (default),
15
+ "score", or "noise"
16
+ - loss_weight: optional loss weighting; "velocity" or "likelihood",
17
+ or None (default) for an unweighted loss
18
+ - train_eps: small epsilon for avoiding instability during training
19
+ - sample_eps: small epsilon for avoiding instability during sampling
20
+ """
21
+
22
+ if prediction == "noise":
23
+ model_type = ModelType.NOISE
24
+ elif prediction == "score":
25
+ model_type = ModelType.SCORE
26
+ else:
27
+ model_type = ModelType.VELOCITY
28
+
29
+ if loss_weight == "velocity":
30
+ loss_type = WeightType.VELOCITY
31
+ elif loss_weight == "likelihood":
32
+ loss_type = WeightType.LIKELIHOOD
33
+ else:
34
+ loss_type = WeightType.NONE
35
+
36
+ path_choice = {
37
+ "Linear": PathType.LINEAR,
38
+ "GVP": PathType.GVP,
39
+ "VP": PathType.VP,
40
+ }
41
+
42
+ path_type = path_choice[path_type]
43
+
44
+ if (path_type in [PathType.VP]):
45
+ train_eps = 1e-5 if train_eps is None else train_eps
46
+ sample_eps = 1e-3 if sample_eps is None else sample_eps
47
+ elif (path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY):
48
+ train_eps = 1e-3 if train_eps is None else train_eps
49
+ sample_eps = 1e-3 if sample_eps is None else sample_eps
50
+ else: # velocity & [GVP, LINEAR] is stable everywhere
51
+ train_eps = 0
52
+ sample_eps = 0
53
+
54
+ # create flow state
55
+ state = Transport(
56
+ model_type=model_type,
57
+ path_type=path_type,
58
+ loss_type=loss_type,
59
+ train_eps=train_eps,
60
+ sample_eps=sample_eps,
61
+ )
62
+
63
+ return state
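
A minimal usage sketch, grounded in the factory above (the keyword values shown are just its defaults):

# Linear path with velocity prediction: both epsilons resolve to 0,
# since that combination is stable everywhere (see the branch above).
transport = create_transport(path_type="Linear", prediction="velocity")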
semanticist/stage1/transport/integrators.py ADDED
@@ -0,0 +1,130 @@
1
+ import numpy as np
2
+ import torch as th
3
+ import torch.nn as nn
4
+ from torchdiffeq import odeint
5
+ from functools import partial
6
+ from tqdm import tqdm
7
+
8
+ class sde:
9
+ """SDE solver class"""
10
+ def __init__(
11
+ self,
12
+ drift,
13
+ diffusion,
14
+ *,
15
+ t0,
16
+ t1,
17
+ num_steps,
18
+ sampler_type,
19
+ temperature=1.0,
20
+ ):
21
+ assert t0 < t1, "SDE sampler has to be in forward time"
22
+
23
+ self.num_timesteps = num_steps
24
+ self.t = th.linspace(t0, t1, num_steps)
25
+ self.dt = self.t[1] - self.t[0]
26
+ self.drift = drift
27
+ self.diffusion = diffusion
28
+ self.sampler_type = sampler_type
29
+ self.temperature = temperature
30
+
31
+ def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs):
32
+ w_cur = th.randn(x.size()).to(x)
33
+ t = th.ones(x.size(0)).to(x) * t
34
+ dw = w_cur * th.sqrt(self.dt)
35
+ drift = self.drift(x, t, model, **model_kwargs)
36
+ diffusion = self.diffusion(x, t)
37
+ mean_x = x + drift * self.dt
38
+ x = mean_x + th.sqrt(2 * diffusion) * dw * self.temperature
39
+ return x, mean_x
40
+
41
+ def __Heun_step(self, x, _, t, model, **model_kwargs):
42
+ w_cur = th.randn(x.size()).to(x)
43
+ dw = w_cur * th.sqrt(self.dt) * self.temperature
44
+ t_cur = th.ones(x.size(0)).to(x) * t
45
+ diffusion = self.diffusion(x, t_cur)
46
+ xhat = x + th.sqrt(2 * diffusion) * dw
47
+ K1 = self.drift(xhat, t_cur, model, **model_kwargs)
48
+ xp = xhat + self.dt * K1
49
+ K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs)
50
+ return xhat + 0.5 * self.dt * (K1 + K2), xhat # at last time point we do not perform the heun step
51
+
52
+ def __forward_fn(self):
53
+ """TODO: generalize here by adding all private functions ending with steps to it"""
54
+ sampler_dict = {
55
+ "Euler": self.__Euler_Maruyama_step,
56
+ "Heun": self.__Heun_step,
57
+ }
58
+
59
+ try:
60
+ sampler = sampler_dict[self.sampler_type]
61
+ except KeyError:
62
+ raise NotImplementedError("Sampler type not implemented.")
63
+
64
+ return sampler
65
+
66
+ def sample(self, init, model, **model_kwargs):
67
+ """forward loop of sde"""
68
+ x = init
69
+ mean_x = init
70
+ samples = []
71
+ sampler = self.__forward_fn()
72
+ for ti in self.t[:-1]:
73
+ with th.no_grad():
74
+ x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs)
75
+ samples.append(x)
76
+
77
+ return samples
78
+
79
+ class ode:
80
+ """ODE solver class"""
81
+ def __init__(
82
+ self,
83
+ drift,
84
+ *,
85
+ t0,
86
+ t1,
87
+ sampler_type,
88
+ num_steps,
89
+ atol,
90
+ rtol,
91
+ temperature=1.0,
92
+ ):
93
+ assert t0 < t1, "ODE sampler has to be in forward time"
94
+
95
+ self.drift = drift
96
+ self.t = th.linspace(t0, t1, num_steps)
97
+ self.atol = atol
98
+ self.rtol = rtol
99
+ self.sampler_type = sampler_type
100
+ self.temperature = temperature
101
+
102
+ def sample(self, x, model, **model_kwargs):
103
+
104
+ device = x[0].device if isinstance(x, tuple) else x.device
105
+ def _fn(t, x):
106
+ t = th.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else th.ones(x.size(0)).to(device) * t
107
+ # For ODE, we scale the drift by the temperature
108
+ # This is equivalent to scaling time by 1/temperature
109
+ model_output = self.drift(x, t, model, **model_kwargs)
110
+ if self.temperature != 1.0:
111
+ # If it's a tuple (for likelihood calculation), only scale the first element
112
+ if isinstance(model_output, tuple):
113
+ scaled_output = (model_output[0] / self.temperature, model_output[1])
114
+ return scaled_output
115
+ else:
116
+ return model_output / self.temperature
117
+ return model_output
118
+
119
+ t = self.t.to(device)
120
+ atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol]
121
+ rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol]
122
+ samples = odeint(
123
+ _fn,
124
+ x,
125
+ t,
126
+ method=self.sampler_type,
127
+ atol=atol,
128
+ rtol=rtol
129
+ )
130
+ return samples
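A minimal usage sketch of the `ode` solver above (illustrative, not part of the commit; assumes this repo is importable and that `odeint` comes from torchdiffeq as imported at the top of the file):

import torch as th
from semanticist.stage1.transport.integrators import ode

def toy_drift(x, t, model, **kwargs):
    # a simple linear vector field; `model` is unused in this toy example
    return -x

solver = ode(drift=toy_drift, t0=0.0, t1=1.0, sampler_type="dopri5",
             num_steps=50, atol=1e-6, rtol=1e-3)
trajectory = solver.sample(th.randn(4, 16), model=None)  # tensor of shape [50, 4, 16]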
semanticist/stage1/transport/path.py ADDED
@@ -0,0 +1,192 @@
1
+ import torch as th
2
+ import numpy as np
3
+ from functools import partial
4
+
5
+ def expand_t_like_x(t, x):
6
+ """Function to reshape time t to broadcastable dimension of x
7
+ Args:
8
+ t: [batch_dim,], time vector
9
+ x: [batch_dim,...], data point
10
+ """
11
+ dims = [1] * (len(x.size()) - 1)
12
+ t = t.view(t.size(0), *dims)
13
+ return t
14
+
15
+
16
+ #################### Coupling Plans ####################
17
+
18
+ class ICPlan:
19
+ """Linear Coupling Plan"""
20
+ def __init__(self, sigma=0.0):
21
+ self.sigma = sigma
22
+
23
+ def compute_alpha_t(self, t):
24
+ """Compute the data coefficient along the path"""
25
+ return t, 1
26
+
27
+ def compute_sigma_t(self, t):
28
+ """Compute the noise coefficient along the path"""
29
+ return 1 - t, -1
30
+
31
+ def compute_d_alpha_alpha_ratio_t(self, t):
32
+ """Compute the ratio between d_alpha and alpha"""
33
+ return 1 / t
34
+
35
+ def compute_drift(self, x, t):
36
+ """We always output sde according to score parametrization; """
37
+ t = expand_t_like_x(t, x)
38
+ alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t)
39
+ sigma_t, d_sigma_t = self.compute_sigma_t(t)
40
+ drift = alpha_ratio * x
41
+ diffusion = alpha_ratio * (sigma_t ** 2) - sigma_t * d_sigma_t
42
+
43
+ return -drift, diffusion
44
+
45
+ def compute_diffusion(self, x, t, form="constant", norm=1.0):
46
+ """Compute the diffusion term of the SDE
47
+ Args:
48
+ x: [batch_dim, ...], data point
49
+ t: [batch_dim,], time vector
50
+ form: str, form of the diffusion term
51
+ norm: float, norm of the diffusion term
52
+ """
53
+ t = expand_t_like_x(t, x)
54
+ choices = {
55
+ "constant": norm,
56
+ "SBDM": norm * self.compute_drift(x, t)[1],
57
+ "sigma": norm * self.compute_sigma_t(t)[0],
58
+ "linear": norm * (1 - t),
59
+ "decreasing": 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2,
60
+ "inccreasing-decreasing": norm * th.sin(np.pi * t) ** 2,
61
+ }
62
+
63
+ try:
64
+ diffusion = choices[form]
65
+ except KeyError:
66
+ raise NotImplementedError(f"Diffusion form {form} not implemented")
67
+
68
+ return diffusion
69
+
70
+ def get_score_from_velocity(self, velocity, x, t):
71
+ """Wrapper function: transfrom velocity prediction model to score
72
+ Args:
73
+ velocity: [batch_dim, ...] shaped tensor; velocity model output
74
+ x: [batch_dim, ...] shaped tensor; x_t data point
75
+ t: [batch_dim,] time tensor
76
+ """
77
+ t = expand_t_like_x(t, x)
78
+ alpha_t, d_alpha_t = self.compute_alpha_t(t)
79
+ sigma_t, d_sigma_t = self.compute_sigma_t(t)
80
+ mean = x
81
+ reverse_alpha_ratio = alpha_t / d_alpha_t
82
+ var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
83
+ score = (reverse_alpha_ratio * velocity - mean) / var
84
+ return score
85
+
86
+ def get_noise_from_velocity(self, velocity, x, t):
87
+ """Wrapper function: transfrom velocity prediction model to denoiser
88
+ Args:
89
+ velocity: [batch_dim, ...] shaped tensor; velocity model output
90
+ x: [batch_dim, ...] shaped tensor; x_t data point
91
+ t: [batch_dim,] time tensor
92
+ """
93
+ t = expand_t_like_x(t, x)
94
+ alpha_t, d_alpha_t = self.compute_alpha_t(t)
95
+ sigma_t, d_sigma_t = self.compute_sigma_t(t)
96
+ mean = x
97
+ reverse_alpha_ratio = alpha_t / d_alpha_t
98
+ var = reverse_alpha_ratio * d_sigma_t - sigma_t
99
+ noise = (reverse_alpha_ratio * velocity - mean) / var
100
+ return noise
101
+
102
+ def get_velocity_from_score(self, score, x, t):
103
+ """Wrapper function: transfrom score prediction model to velocity
104
+ Args:
105
+ score: [batch_dim, ...] shaped tensor; score model output
106
+ x: [batch_dim, ...] shaped tensor; x_t data point
107
+ t: [batch_dim,] time tensor
108
+ """
109
+ t = expand_t_like_x(t, x)
110
+ drift, var = self.compute_drift(x, t)
111
+ velocity = var * score - drift
112
+ return velocity
113
+
114
+ def compute_mu_t(self, t, x0, x1):
115
+ """Compute the mean of time-dependent density p_t"""
116
+ t = expand_t_like_x(t, x1)
117
+ alpha_t, _ = self.compute_alpha_t(t)
118
+ sigma_t, _ = self.compute_sigma_t(t)
119
+ return alpha_t * x1 + sigma_t * x0
120
+
121
+ def compute_xt(self, t, x0, x1):
122
+ """Sample xt from time-dependent density p_t; rng is required"""
123
+ xt = self.compute_mu_t(t, x0, x1)
124
+ return xt
125
+
126
+ def compute_ut(self, t, x0, x1, xt):
127
+ """Compute the vector field corresponding to p_t"""
128
+ t = expand_t_like_x(t, x1)
129
+ _, d_alpha_t = self.compute_alpha_t(t)
130
+ _, d_sigma_t = self.compute_sigma_t(t)
131
+ return d_alpha_t * x1 + d_sigma_t * x0
132
+
133
+ def plan(self, t, x0, x1):
134
+ xt = self.compute_xt(t, x0, x1)
135
+ ut = self.compute_ut(t, x0, x1, xt)
136
+ return t, xt, ut
137
+
138
+
139
+ class VPCPlan(ICPlan):
140
+ """class for VP path flow matching"""
141
+
142
+ def __init__(self, sigma_min=0.1, sigma_max=20.0):
143
+ self.sigma_min = sigma_min
144
+ self.sigma_max = sigma_max
145
+ self.log_mean_coeff = lambda t: -0.25 * ((1 - t) ** 2) * (self.sigma_max - self.sigma_min) - 0.5 * (1 - t) * self.sigma_min
146
+ self.d_log_mean_coeff = lambda t: 0.5 * (1 - t) * (self.sigma_max - self.sigma_min) + 0.5 * self.sigma_min
147
+
148
+
149
+ def compute_alpha_t(self, t):
150
+ """Compute coefficient of x1"""
151
+ alpha_t = self.log_mean_coeff(t)
152
+ alpha_t = th.exp(alpha_t)
153
+ d_alpha_t = alpha_t * self.d_log_mean_coeff(t)
154
+ return alpha_t, d_alpha_t
155
+
156
+ def compute_sigma_t(self, t):
157
+ """Compute coefficient of x0"""
158
+ p_sigma_t = 2 * self.log_mean_coeff(t)
159
+ sigma_t = th.sqrt(1 - th.exp(p_sigma_t))
160
+ d_sigma_t = th.exp(p_sigma_t) * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t)
161
+ return sigma_t, d_sigma_t
162
+
163
+ def compute_d_alpha_alpha_ratio_t(self, t):
164
+ """Special purposed function for computing numerical stabled d_alpha_t / alpha_t"""
165
+ return self.d_log_mean_coeff(t)
166
+
167
+ def compute_drift(self, x, t):
168
+ """Compute the drift term of the SDE"""
169
+ t = expand_t_like_x(t, x)
170
+ beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min)
171
+ return -0.5 * beta_t * x, beta_t / 2
172
+
173
+
174
+ class GVPCPlan(ICPlan):
175
+ def __init__(self, sigma=0.0):
176
+ super().__init__(sigma)
177
+
178
+ def compute_alpha_t(self, t):
179
+ """Compute coefficient of x1"""
180
+ alpha_t = th.sin(t * np.pi / 2)
181
+ d_alpha_t = np.pi / 2 * th.cos(t * np.pi / 2)
182
+ return alpha_t, d_alpha_t
183
+
184
+ def compute_sigma_t(self, t):
185
+ """Compute coefficient of x0"""
186
+ sigma_t = th.cos(t * np.pi / 2)
187
+ d_sigma_t = -np.pi / 2 * th.sin(t * np.pi / 2)
188
+ return sigma_t, d_sigma_t
189
+
190
+ def compute_d_alpha_alpha_ratio_t(self, t):
191
+ """Special purposed function for computing numerical stabled d_alpha_t / alpha_t"""
192
+ return np.pi / (2 * th.tan(t * np.pi / 2))
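As a quick illustration (not part of the commit) of how these coupling plans are consumed: for the linear ICPlan, alpha_t = t and sigma_t = 1 - t, so plan() returns xt = t * x1 + (1 - t) * x0 with target velocity ut = x1 - x0:

import torch as th
from semanticist.stage1.transport.path import ICPlan

plan = ICPlan()
x1 = th.randn(8, 16)      # data batch
x0 = th.randn_like(x1)    # noise batch
t = th.rand(8)            # one time per sample
t, xt, ut = plan.plan(t, x0, x1)
assert th.allclose(ut, x1 - x0)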
semanticist/stage1/transport/transport.py ADDED
@@ -0,0 +1,456 @@
1
+ import torch as th
2
+ import numpy as np
3
+ import logging
4
+
5
+ import enum
6
+
7
+ from . import path
8
+ from .utils import EasyDict, log_state, mean_flat
9
+ from .integrators import ode, sde
10
+
11
+ class ModelType(enum.Enum):
12
+ """
13
+ Which type of output the model predicts.
14
+ """
15
+
16
+ NOISE = enum.auto() # the model predicts epsilon
17
+ SCORE = enum.auto() # the model predicts \nabla \log p(x)
18
+ VELOCITY = enum.auto() # the model predicts v(x)
19
+
20
+ class PathType(enum.Enum):
21
+ """
22
+ Which type of path to use.
23
+ """
24
+
25
+ LINEAR = enum.auto()
26
+ GVP = enum.auto()
27
+ VP = enum.auto()
28
+
29
+ class WeightType(enum.Enum):
30
+ """
31
+ Which type of weighting to use.
32
+ """
33
+
34
+ NONE = enum.auto()
35
+ VELOCITY = enum.auto()
36
+ LIKELIHOOD = enum.auto()
37
+
38
+
39
+ class Transport:
40
+
41
+ def __init__(
42
+ self,
43
+ *,
44
+ model_type,
45
+ path_type,
46
+ loss_type,
47
+ train_eps,
48
+ sample_eps,
49
+ ):
50
+ path_options = {
51
+ PathType.LINEAR: path.ICPlan,
52
+ PathType.GVP: path.GVPCPlan,
53
+ PathType.VP: path.VPCPlan,
54
+ }
55
+
56
+ self.loss_type = loss_type
57
+ self.model_type = model_type
58
+ self.path_sampler = path_options[path_type]()
59
+ self.train_eps = train_eps
60
+ self.sample_eps = sample_eps
61
+
62
+ def prior_logp(self, z):
63
+ '''
64
+ Standard multivariate normal prior
65
+ Assume z is batched
66
+ '''
67
+ shape = th.tensor(z.size())
68
+ N = th.prod(shape[1:])
69
+ _fn = lambda x: -N / 2. * np.log(2 * np.pi) - th.sum(x ** 2) / 2.
70
+ return th.vmap(_fn)(z)
71
+
72
+
73
+ def check_interval(
74
+ self,
75
+ train_eps,
76
+ sample_eps,
77
+ *,
78
+ diffusion_form="SBDM",
79
+ sde=False,
80
+ reverse=False,
81
+ eval=False,
82
+ last_step_size=0.0,
83
+ ):
84
+ t0 = 0
85
+ t1 = 1
86
+ eps = train_eps if not eval else sample_eps
87
+ if (type(self.path_sampler) in [path.VPCPlan]):
88
+
89
+ t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size
90
+
91
+ elif (type(self.path_sampler) in [path.ICPlan, path.GVPCPlan]) \
92
+ and (self.model_type != ModelType.VELOCITY or sde): # avoid numerical issue by taking a first semi-implicit step
93
+
94
+ t0 = eps if (diffusion_form == "SBDM" and sde) or self.model_type != ModelType.VELOCITY else 0
95
+ t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size
96
+
97
+ if reverse:
98
+ t0, t1 = 1 - t0, 1 - t1
99
+
100
+ return t0, t1
101
+
102
+
103
+ def sample(self, x1):
104
+ """Sampling x0 & t based on shape of x1 (if needed)
105
+ Args:
106
+ x1 - data point; [batch, *dim]
107
+ """
108
+
109
+ x0 = th.randn_like(x1)
110
+ t0, t1 = self.check_interval(self.train_eps, self.sample_eps)
111
+ t = th.rand((x1.shape[0],)) * (t1 - t0) + t0
112
+ t = t.to(x1)
113
+ return t, x0, x1
114
+
115
+
116
+ def training_losses(
117
+ self,
118
+ model,
119
+ x1,
120
+ model_kwargs=None
121
+ ):
122
+ """Loss for training the score model
123
+ Args:
124
+ - model: backbone model; could be score, noise, or velocity
125
+ - x1: datapoint
126
+ - model_kwargs: additional arguments for the model
127
+ """
128
+ if model_kwargs is None:
129
+ model_kwargs = {}
130
+
131
+ t, x0, x1 = self.sample(x1)
132
+ t, xt, ut = self.path_sampler.plan(t, x0, x1)
133
+ model_output = model(xt, t, **model_kwargs)
134
+ if len(model_output.shape) == len(xt.shape) + 1:
135
+ x0 = x0.unsqueeze(-1).expand(*([-1] * (len(x0.shape))), model_output.shape[-1])
136
+ xt = xt.unsqueeze(-1).expand(*([-1] * (len(xt.shape))), model_output.shape[-1])
137
+ ut = ut.unsqueeze(-1).expand(*([-1] * (len(ut.shape))), model_output.shape[-1])
138
+ B, C = xt.shape[:2]
139
+ assert model_output.shape == (B, C, *xt.shape[2:])
140
+
141
+ terms = {}
142
+ terms['pred'] = model_output
143
+ if self.model_type == ModelType.VELOCITY:
144
+ terms['loss'] = mean_flat(((model_output - ut) ** 2))
145
+ else:
146
+ _, drift_var = self.path_sampler.compute_drift(xt, t)
147
+ sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, xt))
148
+ if self.loss_type in [WeightType.VELOCITY]:
149
+ weight = (drift_var / sigma_t) ** 2
150
+ elif self.loss_type in [WeightType.LIKELIHOOD]:
151
+ weight = drift_var / (sigma_t ** 2)
152
+ elif self.loss_type in [WeightType.NONE]:
153
+ weight = 1
154
+ else:
155
+ raise NotImplementedError()
156
+
157
+ if self.model_type == ModelType.NOISE:
158
+ terms['loss'] = mean_flat(weight * ((model_output - x0) ** 2))
159
+ else:
160
+ terms['loss'] = mean_flat(weight * ((model_output * sigma_t + x0) ** 2))
161
+
162
+ return terms
163
+
164
+
165
+ def get_drift(
166
+ self
167
+ ):
168
+ """member function for obtaining the drift of the probability flow ODE"""
169
+ def score_ode(x, t, model, **model_kwargs):
170
+ drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
171
+ model_output = model(x, t, **model_kwargs)
172
+ return (-drift_mean + drift_var * model_output) # by change of variable
173
+
174
+ def noise_ode(x, t, model, **model_kwargs):
175
+ drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
176
+ sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))
177
+ model_output = model(x, t, **model_kwargs)
178
+ score = model_output / -sigma_t
179
+ return (-drift_mean + drift_var * score)
180
+
181
+ def velocity_ode(x, t, model, **model_kwargs):
182
+ model_output = model(x, t, **model_kwargs)
183
+ return model_output
184
+
185
+ if self.model_type == ModelType.NOISE:
186
+ drift_fn = noise_ode
187
+ elif self.model_type == ModelType.SCORE:
188
+ drift_fn = score_ode
189
+ else:
190
+ drift_fn = velocity_ode
191
+
192
+ def body_fn(x, t, model, **model_kwargs):
193
+ model_output = drift_fn(x, t, model, **model_kwargs)
194
+ assert model_output.shape == x.shape, "Output shape from ODE solver must match input shape"
195
+ return model_output
196
+
197
+ return body_fn
198
+
199
+
200
+ def get_score(
201
+ self,
202
+ ):
203
+ """member function for obtaining score of
204
+ x_t = alpha_t * x + sigma_t * eps"""
205
+ if self.model_type == ModelType.NOISE:
206
+ score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs) / -self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0]
207
+ elif self.model_type == ModelType.SCORE:
208
+ score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs)
209
+ elif self.model_type == ModelType.VELOCITY:
210
+ score_fn = lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity(model(x, t, **kwargs), x, t)
211
+ else:
212
+ raise NotImplementedError()
213
+
214
+ return score_fn
215
+
216
+
217
+ class Sampler:
218
+ """Sampler class for the transport model"""
219
+ def __init__(
220
+ self,
221
+ transport,
222
+ ):
223
+ """Constructor for a general sampler; supporting different sampling methods
224
+ Args:
225
+ - transport: a transport object specifying the model prediction & interpolant type
226
+ """
227
+
228
+ self.transport = transport
229
+ self.drift = self.transport.get_drift()
230
+ self.score = self.transport.get_score()
231
+
232
+ def __get_sde_diffusion_and_drift(
233
+ self,
234
+ *,
235
+ diffusion_form="SBDM",
236
+ diffusion_norm=1.0,
237
+ ):
238
+
239
+ def diffusion_fn(x, t):
240
+ diffusion = self.transport.path_sampler.compute_diffusion(x, t, form=diffusion_form, norm=diffusion_norm)
241
+ return diffusion
242
+
243
+ sde_drift = \
244
+ lambda x, t, model, **kwargs: \
245
+ self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(x, t, model, **kwargs)
246
+
247
+ sde_diffusion = diffusion_fn
248
+
249
+ return sde_drift, sde_diffusion
250
+
251
+ def __get_last_step(
252
+ self,
253
+ sde_drift,
254
+ *,
255
+ last_step,
256
+ last_step_size,
257
+ ):
258
+ """Get the last step function of the SDE solver"""
259
+
260
+ if last_step is None:
261
+ last_step_fn = \
262
+ lambda x, t, model, **model_kwargs: \
263
+ x
264
+ elif last_step == "Mean":
265
+ last_step_fn = \
266
+ lambda x, t, model, **model_kwargs: \
267
+ x + sde_drift(x, t, model, **model_kwargs) * last_step_size
268
+ elif last_step == "Tweedie":
269
+ alpha = self.transport.path_sampler.compute_alpha_t # simple aliasing; the original name was too long
270
+ sigma = self.transport.path_sampler.compute_sigma_t
271
+ last_step_fn = \
272
+ lambda x, t, model, **model_kwargs: \
273
+ x / alpha(t)[0][0] + (sigma(t)[0][0] ** 2) / alpha(t)[0][0] * self.score(x, t, model, **model_kwargs)
274
+ elif last_step == "Euler":
275
+ last_step_fn = \
276
+ lambda x, t, model, **model_kwargs: \
277
+ x + self.drift(x, t, model, **model_kwargs) * last_step_size
278
+ else:
279
+ raise NotImplementedError()
280
+
281
+ return last_step_fn
282
+
283
+ def sample_sde(
284
+ self,
285
+ *,
286
+ sampling_method="Euler",
287
+ diffusion_form="SBDM",
288
+ diffusion_norm=1.0,
289
+ last_step="Mean",
290
+ last_step_size=0.04,
291
+ num_steps=250,
292
+ temperature=1.0,
293
+ ):
294
+ """returns a sampling function with given SDE settings
295
+ Args:
296
+ - sampling_method: type of sampler used in solving the SDE; defaults to Euler-Maruyama
297
+ - diffusion_form: functional form of the diffusion coefficient; defaults to matching SBDM
298
+ - diffusion_norm: magnitude of the diffusion coefficient; defaults to 1
299
+ - last_step: type of the last step; defaults to "Mean"
300
+ - last_step_size: size of the last step; defaults to 0.04, matching the stride of 250 steps over [0, 1]
301
+ - num_steps: total number of SDE integration steps
302
+ - temperature: temperature scaling for the noise during sampling; defaults to 1.0
303
+ """
304
+
305
+ if last_step is None:
306
+ last_step_size = 0.0
307
+
308
+ sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift(
309
+ diffusion_form=diffusion_form,
310
+ diffusion_norm=diffusion_norm,
311
+ )
312
+
313
+ t0, t1 = self.transport.check_interval(
314
+ self.transport.train_eps,
315
+ self.transport.sample_eps,
316
+ diffusion_form=diffusion_form,
317
+ sde=True,
318
+ eval=True,
319
+ reverse=False,
320
+ last_step_size=last_step_size,
321
+ )
322
+
323
+ _sde = sde(
324
+ sde_drift,
325
+ sde_diffusion,
326
+ t0=t0,
327
+ t1=t1,
328
+ num_steps=num_steps,
329
+ sampler_type=sampling_method,
330
+ temperature=temperature
331
+ )
332
+
333
+ last_step_fn = self.__get_last_step(sde_drift, last_step=last_step, last_step_size=last_step_size)
334
+
335
+
336
+ def _sample(init, model, **model_kwargs):
337
+ xs = _sde.sample(init, model, **model_kwargs)
338
+ ts = th.ones(init.size(0), device=init.device) * t1
339
+ x = last_step_fn(xs[-1], ts, model, **model_kwargs)
340
+ xs.append(x)
341
+
342
+ assert len(xs) == num_steps, "Number of samples does not match the number of steps"
343
+
344
+ return xs
345
+
346
+ return _sample
347
+
348
+ def sample_ode(
349
+ self,
350
+ *,
351
+ sampling_method="dopri5",
352
+ num_steps=50,
353
+ atol=1e-6,
354
+ rtol=1e-3,
355
+ reverse=False,
356
+ temperature=1.0,
357
+ ):
358
+ """returns a sampling function with given ODE settings
359
+ Args:
360
+ - sampling_method: type of sampler used in solving the ODE; defaults to Dopri5
361
+ - num_steps:
362
+ - fixed solver (Euler, Heun): the actual number of integration steps performed
363
+ - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
364
+ - atol: absolute error tolerance for the solver
365
+ - rtol: relative error tolerance for the solver
366
+ - reverse: whether to solve the ODE in reverse (data to noise); defaults to False
367
+ - temperature: temperature scaling for the drift during sampling; defaults to 1.0
368
+ """
369
+ if reverse:
370
+ drift = lambda x, t, model, **kwargs: self.drift(x, th.ones_like(t) * (1 - t), model, **kwargs)
371
+ else:
372
+ drift = self.drift
373
+
374
+ t0, t1 = self.transport.check_interval(
375
+ self.transport.train_eps,
376
+ self.transport.sample_eps,
377
+ sde=False,
378
+ eval=True,
379
+ reverse=reverse,
380
+ last_step_size=0.0,
381
+ )
382
+
383
+ _ode = ode(
384
+ drift=drift,
385
+ t0=t0,
386
+ t1=t1,
387
+ sampler_type=sampling_method,
388
+ num_steps=num_steps,
389
+ atol=atol,
390
+ rtol=rtol,
391
+ temperature=temperature,
392
+ )
393
+
394
+ return _ode.sample
395
+
396
+ def sample_ode_likelihood(
397
+ self,
398
+ *,
399
+ sampling_method="dopri5",
400
+ num_steps=50,
401
+ atol=1e-6,
402
+ rtol=1e-3,
403
+ temperature=1.0,
404
+ ):
405
+
406
+ """returns a sampling function for calculating likelihood with given ODE settings
407
+ Args:
408
+ - sampling_method: type of sampler used in solving the ODE; defaults to Dopri5
409
+ - num_steps:
410
+ - fixed solver (Euler, Heun): the actual number of integration steps performed
411
+ - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
412
+ - atol: absolute error tolerance for the solver
413
+ - rtol: relative error tolerance for the solver
414
+ - temperature: temperature scaling for the drift during sampling; defaults to 1.0
415
+ """
416
+ def _likelihood_drift(x, t, model, **model_kwargs):
417
+ x, _ = x
418
+ eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1
419
+ t = th.ones_like(t) * (1 - t)
420
+ with th.enable_grad():
421
+ x.requires_grad = True
422
+ grad = th.autograd.grad(th.sum(self.drift(x, t, model, **model_kwargs) * eps), x)[0]
423
+ logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size()))))
424
+ drift = self.drift(x, t, model, **model_kwargs)
425
+ return (-drift, logp_grad)
426
+
427
+ t0, t1 = self.transport.check_interval(
428
+ self.transport.train_eps,
429
+ self.transport.sample_eps,
430
+ sde=False,
431
+ eval=True,
432
+ reverse=False,
433
+ last_step_size=0.0,
434
+ )
435
+
436
+ _ode = ode(
437
+ drift=_likelihood_drift,
438
+ t0=t0,
439
+ t1=t1,
440
+ sampler_type=sampling_method,
441
+ num_steps=num_steps,
442
+ atol=atol,
443
+ rtol=rtol,
444
+ temperature=temperature,
445
+ )
446
+
447
+ def _sample_fn(x, model, **model_kwargs):
448
+ init_logp = th.zeros(x.size(0)).to(x)
449
+ input = (x, init_logp)
450
+ drift, delta_logp = _ode.sample(input, model, **model_kwargs)
451
+ drift, delta_logp = drift[-1], delta_logp[-1]
452
+ prior_logp = self.transport.prior_logp(drift)
453
+ logp = prior_logp - delta_logp
454
+ return logp, drift
455
+
456
+ return _sample_fn
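A rough end-to-end sketch (not part of the commit) of how Transport and Sampler fit together; it assumes the create_transport factory exported by this package's __init__.py with its default velocity-prediction, linear-path settings, and uses a hypothetical toy network in place of the real denoiser:

import torch as th
import torch.nn as nn
from semanticist.stage1.transport import create_transport, Sampler

class TinyVelocityNet(nn.Module):
    # hypothetical stand-in for the real model: maps (x, t) to a velocity
    def __init__(self, dim=16):
        super().__init__()
        self.net = nn.Linear(dim + 1, dim)
    def forward(self, x, t, **kwargs):
        return self.net(th.cat([x, t[:, None]], dim=-1))

transport = create_transport()  # factory defaults assumed
model = TinyVelocityNet()
loss = transport.training_losses(model, th.randn(8, 16))["loss"].mean()

sampler = Sampler(transport)
sample_fn = sampler.sample_sde(diffusion_form="sigma", num_steps=50)
xs = sample_fn(th.randn(8, 16), model)  # list of states; xs[-1] is the final sample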
semanticist/stage1/transport/utils.py ADDED
@@ -0,0 +1,29 @@
1
+ import torch as th
2
+
3
+ class EasyDict:
4
+
5
+ def __init__(self, sub_dict):
6
+ for k, v in sub_dict.items():
7
+ setattr(self, k, v)
8
+
9
+ def __getitem__(self, key):
10
+ return getattr(self, key)
11
+
12
+ def mean_flat(x):
13
+ """
14
+ Take the mean over all non-batch dimensions.
15
+ """
16
+ return th.mean(x, dim=list(range(1, len(x.size()))))
17
+
18
+ def log_state(state):
19
+ result = []
20
+
21
+ sorted_state = dict(sorted(state.items()))
22
+ for key, value in sorted_state.items():
23
+ # Check if the value is an instance of a class
24
+ if "<object" in str(value) or "object at" in str(value):
25
+ result.append(f"{key}: [{value.__class__.__name__}]")
26
+ else:
27
+ result.append(f"{key}: {value}")
28
+
29
+ return '\n'.join(result)
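A tiny illustration (not part of the commit) of the two helpers above:

import torch as th
from semanticist.stage1.transport.utils import EasyDict, mean_flat

cfg = EasyDict({"lr": 1e-4, "steps": 100})
print(cfg.lr, cfg["steps"])                # attribute and item access
print(mean_flat(th.randn(4, 3, 2)).shape)  # torch.Size([4])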
semanticist/stage1/vision_transformer.py ADDED
@@ -0,0 +1,259 @@
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Mostly copy-paste from timm library.
16
+ https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
17
+ """
18
+ import math
19
+ import torch
20
+ import torch.nn as nn
21
+
22
+ from functools import partial
23
+ from semanticist.stage1.fused_attention import Attention
24
+
25
+ __all__ = ['VisionTransformer', 'vit_tiny_patch16', 'vit_small_patch16',
26
+ 'vit_base_patch16', 'vit_large_patch16', 'vit_huge_patch14']
27
+
28
+
29
+ def drop_path(x, drop_prob: float = 0., training: bool = False):
30
+ if drop_prob == 0. or not training:
31
+ return x
32
+ keep_prob = 1 - drop_prob
33
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
34
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
35
+ if keep_prob > 0.0:
36
+ random_tensor.div_(keep_prob)
37
+ return x * random_tensor
38
+
39
+
40
+ class DropPath(nn.Module):
41
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
42
+ """
43
+
44
+ def __init__(self, drop_prob=None):
45
+ super(DropPath, self).__init__()
46
+ self.drop_prob = drop_prob
47
+
48
+ def forward(self, x):
49
+ return drop_path(x, self.drop_prob, self.training)
50
+
51
+
52
+ class Mlp(nn.Module):
53
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
54
+ super().__init__()
55
+ out_features = out_features or in_features
56
+ hidden_features = hidden_features or in_features
57
+ self.fc1 = nn.Linear(in_features, hidden_features)
58
+ self.act = act_layer()
59
+ self.fc2 = nn.Linear(hidden_features, out_features)
60
+ self.drop = nn.Dropout(drop)
61
+
62
+ def forward(self, x):
63
+ x = self.fc1(x)
64
+ x = self.act(x)
65
+ x = self.drop(x)
66
+ x = self.fc2(x)
67
+ x = self.drop(x)
68
+ return x
69
+
70
+
71
+ class Block(nn.Module):
72
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0.,
73
+ attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, init_values=0):
74
+ super().__init__()
75
+ self.norm1 = norm_layer(dim)
76
+ self.attn = Attention(
77
+ dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
78
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
79
+ self.norm2 = norm_layer(dim)
80
+ mlp_hidden_dim = int(dim * mlp_ratio)
81
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
82
+
83
+ if init_values > 0:
84
+ self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
85
+ self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
86
+ else:
87
+ self.gamma_1, self.gamma_2 = None, None
88
+
89
+ def forward(self, x, attn_mask=None):
90
+ y = self.attn(self.norm1(x), attn_mask=attn_mask)
91
+ if self.gamma_1 is None:
92
+ x = x + self.drop_path(y)
93
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
94
+ else:
95
+ x = x + self.drop_path(self.gamma_1 * y)
96
+ x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
97
+ return x
98
+
99
+
100
+ class PatchEmbed(nn.Module):
101
+ """ Image to Patch Embedding
102
+ """
103
+
104
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
105
+ super().__init__()
106
+ num_patches = (img_size // patch_size) * (img_size // patch_size)
107
+ self.img_size = img_size
108
+ self.patch_size = patch_size
109
+ self.num_patches = num_patches
110
+
111
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
112
+
113
+ def forward(self, x):
114
+ B, C, H, W = x.shape
115
+ return self.proj(x)
116
+
117
+
118
+ class VisionTransformer(nn.Module):
119
+ """ Vision Transformer """
120
+
121
+ def __init__(self, img_size=[224], patch_size=16, in_chans=3, embed_dim=768, depth=12,
122
+ num_heads=12, mlp_ratio=4., qkv_bias=False, drop_rate=0., attn_drop_rate=0.,
123
+ drop_path_rate=0., norm_layer=partial(nn.LayerNorm, eps=1e-6),
124
+ init_values=0, num_slots=16):
125
+ super().__init__()
126
+ self.num_features = self.embed_dim = embed_dim
127
+
128
+ self.patch_embed = PatchEmbed(
129
+ img_size=img_size[0], patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
130
+ num_patches = self.patch_embed.num_patches
131
+
132
+ self.num_slots = num_slots
133
+
134
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
135
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1 + self.num_slots, embed_dim))
136
+ self.slot_embed = nn.Parameter(torch.zeros(1, num_slots, embed_dim))
137
+
138
+ self.pos_drop = nn.Dropout(p=drop_rate)
139
+
140
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
141
+ self.blocks = nn.ModuleList([
142
+ Block(
143
+ dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
144
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer,
145
+ init_values=init_values)
146
+ for i in range(depth)])
147
+
148
+ self.norm = norm_layer(embed_dim)
149
+
150
+ nn.init.trunc_normal_(self.pos_embed, std=.02)
151
+ nn.init.trunc_normal_(self.cls_token, std=.02)
152
+ nn.init.trunc_normal_(self.slot_embed, std=.02)
153
+ self.apply(self._init_weights)
154
+
155
+ def _init_weights(self, m):
156
+ if isinstance(m, nn.Linear):
157
+ nn.init.trunc_normal_(m.weight, std=.02)
158
+ if isinstance(m, nn.Linear) and m.bias is not None:
159
+ nn.init.constant_(m.bias, 0)
160
+ elif isinstance(m, nn.LayerNorm):
161
+ nn.init.constant_(m.bias, 0)
162
+ nn.init.constant_(m.weight, 1.0)
163
+
164
+ def interpolate_pos_encoding(self, x, w, h):
165
+ npatch = x.shape[1] - 1 - self.num_slots
166
+ N = self.pos_embed.shape[1] - 1 - self.num_slots
167
+ if npatch == N and w == h:
168
+ return self.pos_embed
169
+ class_pos_embed = self.pos_embed[:, 0]
170
+ patch_pos_embed = self.pos_embed[:, 1:1+npatch]
171
+ dim = x.shape[-1]
172
+ w0 = w // self.patch_embed.patch_size  # patch_size is stored as an int in this PatchEmbed
173
+ h0 = h // self.patch_embed.patch_size
174
+ # we add a small number to avoid floating point error in the interpolation
175
+ # see discussion at https://github.com/facebookresearch/dino/issues/8
176
+ w0, h0 = w0 + 0.1, h0 + 0.1
177
+ patch_pos_embed = nn.functional.interpolate(
178
+ patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
179
+ scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
180
+ mode='bicubic',
181
+ )
182
+ assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
183
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
184
+
185
+ slots_pos_embed = self.pos_embed[:, 1+npatch:]
186
+ slots_pos_embed = slots_pos_embed.view(1, -1, dim) # (1, num_slots, dim)
187
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed, slots_pos_embed), dim=1)
188
+
189
+ def prepare_tokens(self, x):
190
+ B, nc, w, h = x.shape
191
+ x = self.patch_embed(x)
192
+ x = x.flatten(2).transpose(1, 2)
193
+ x = torch.cat((self.cls_token.expand(B, -1, -1), x, self.slot_embed.expand(B, -1, -1)), dim=1)
194
+ x = x + self.interpolate_pos_encoding(x, w, h)
195
+ return self.pos_drop(x)
196
+
197
+ def forward(self, x, is_causal=True):
198
+ x = self.prepare_tokens(x)
199
+ if is_causal:
200
+ attn_mask = torch.ones(x.shape[1], x.shape[1], device=x.device, dtype=torch.bool)
201
+ # slots are causal to each other
202
+ causal_mask = torch.ones(self.num_slots, self.num_slots, device=x.device, dtype=torch.bool).tril(diagonal=0)
203
+ attn_mask[-self.num_slots:, -self.num_slots:] = causal_mask
204
+ # cls token and patches should not see slots
205
+ attn_mask[:-self.num_slots, -self.num_slots:] = False
206
+ else:
207
+ attn_mask = None
208
+
209
+ for blk in self.blocks:
210
+ x = blk(x, attn_mask=attn_mask)
211
+
212
+ x = self.norm(x)
213
+ outcome = x[:, -self.num_slots:] # return the slots
214
+ return outcome
215
+
216
+ def get_intermediate_layers(self, x, n=1):
217
+ x = self.prepare_tokens(x)
218
+ # we return the output tokens from the `n` last blocks
219
+ output = []
220
+ for i, blk in enumerate(self.blocks):
221
+ x = blk(x)
222
+ if len(self.blocks) - i <= n:
223
+ output.append(self.norm(x))
224
+ return output
225
+
226
+
227
+ def vit_tiny_patch16(**kwargs):
228
+ model = VisionTransformer(
229
+ patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
230
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
231
+ return model
232
+
233
+
234
+ def vit_small_patch16(**kwargs):
235
+ model = VisionTransformer(
236
+ patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
237
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
238
+ return model
239
+
240
+
241
+ def vit_base_patch16(**kwargs):
242
+ model = VisionTransformer(
243
+ patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
244
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
245
+ return model
246
+
247
+
248
+ def vit_large_patch16(**kwargs):
249
+ model = VisionTransformer(
250
+ patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
251
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
252
+ return model
253
+
254
+
255
+ def vit_huge_patch14(**kwargs):
256
+ model = VisionTransformer(
257
+ patch_size=14, embed_dim=1280, depth=32, num_heads=16, mlp_ratio=4, qkv_bias=True,
258
+ norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
259
+ return model
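A small sketch (not part of the commit) of the slot-augmented encoder defined above; the sizes are illustrative. The forward pass returns only the trailing num_slots tokens:

import torch
from semanticist.stage1.vision_transformer import vit_base_patch16

enc = vit_base_patch16(img_size=[256], num_slots=16)
imgs = torch.randn(2, 3, 256, 256)
slots = enc(imgs, is_causal=True)  # causal attention among the slot tokens
print(slots.shape)                 # torch.Size([2, 16, 768])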
semanticist/stage2/diffloss.py ADDED
@@ -0,0 +1,267 @@
1
+ import math
2
+ import torch
3
+ import torch.nn as nn
4
+
5
+ from semanticist.stage1.diffusion import create_diffusion
6
+ from semanticist.stage1.transport import create_transport, Sampler
7
+
8
+
9
+ class DiffLoss(nn.Module):
10
+ """Diffusion Loss"""
11
+ def __init__(self, target_channels, z_channels, depth, width, num_sampling_steps, predict_xstart=False, use_si=False, cond_method="adaln"):
12
+ super(DiffLoss, self).__init__()
13
+ self.in_channels = target_channels
14
+ self.net = SimpleMLPAdaLN(
15
+ in_channels=target_channels,
16
+ model_channels=width,
17
+ out_channels=target_channels * 2 if not use_si else target_channels, # for vlb loss
18
+ z_channels=z_channels,
19
+ num_res_blocks=depth,
20
+ cond_method=cond_method,
21
+ )
22
+ self.use_si = use_si
23
+ if not use_si:
24
+ self.train_diffusion = create_diffusion(timestep_respacing="", noise_schedule="cosine", predict_xstart=predict_xstart)
25
+ self.gen_diffusion = create_diffusion(timestep_respacing=num_sampling_steps, noise_schedule="cosine", predict_xstart=predict_xstart)
26
+ else:
27
+ self.transport = create_transport()
28
+ self.sampler = Sampler(self.transport)
29
+
30
+ def forward(self, target, z, mask=None):
31
+ model_kwargs = dict(c=z)
32
+ if not self.use_si:
33
+ t = torch.randint(0, self.train_diffusion.num_timesteps, (target.shape[0],), device=target.device)
34
+ loss_dict = self.train_diffusion.training_losses(self.net, target, t, model_kwargs)
35
+ else:
36
+ loss_dict = self.transport.training_losses(self.net, target, model_kwargs)
37
+ loss = loss_dict["loss"]
38
+ if mask is not None:
39
+ loss = (loss * mask).sum() / mask.sum()
40
+ return loss.mean()
41
+
42
+ def sample(self, z, temperature=1.0, cfg=1.0):
43
+ # diffusion loss sampling
44
+ device = z.device
45
+ if cfg != 1.0:
46
+ noise = torch.randn(z.shape[0] // 2, self.in_channels, device=device)
47
+ noise = torch.cat([noise, noise], dim=0)
48
+ model_kwargs = dict(c=z, cfg_scale=cfg)
49
+ sample_fn = self.net.forward_with_cfg
50
+ else:
51
+ noise = torch.randn(z.shape[0], self.in_channels, device=device)
52
+ model_kwargs = dict(c=z)
53
+ sample_fn = self.net.forward
54
+
55
+ if not self.use_si:
56
+ sampled_token_latent = self.gen_diffusion.p_sample_loop(
57
+ sample_fn, noise.shape, noise, clip_denoised=False, model_kwargs=model_kwargs, progress=False,
58
+ temperature=temperature, device=device
59
+ )
60
+ else:
61
+ sde_sample_fn = self.sampler.sample_sde(diffusion_form="sigma", temperature=temperature)
62
+ sampled_token_latent = sde_sample_fn(noise, sample_fn, **model_kwargs)[-1]
63
+ if cfg != 1.0:
64
+ sampled_token_latent, _ = sampled_token_latent.chunk(2, dim=0)
65
+ return sampled_token_latent
66
+
67
+
68
+ def modulate(x, shift, scale):
69
+ return x * (1 + scale) + shift
70
+
71
+
72
+ class TimestepEmbedder(nn.Module):
73
+ """
74
+ Embeds scalar timesteps into vector representations.
75
+ """
76
+ def __init__(self, hidden_size, frequency_embedding_size=256):
77
+ super().__init__()
78
+ self.mlp = nn.Sequential(
79
+ nn.Linear(frequency_embedding_size, hidden_size, bias=True),
80
+ nn.SiLU(),
81
+ nn.Linear(hidden_size, hidden_size, bias=True),
82
+ )
83
+ self.frequency_embedding_size = frequency_embedding_size
84
+
85
+ @staticmethod
86
+ def timestep_embedding(t, dim, max_period=10000):
87
+ """
88
+ Create sinusoidal timestep embeddings.
89
+ :param t: a 1-D Tensor of N indices, one per batch element.
90
+ These may be fractional.
91
+ :param dim: the dimension of the output.
92
+ :param max_period: controls the minimum frequency of the embeddings.
93
+ :return: an (N, D) Tensor of positional embeddings.
94
+ """
95
+ # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
96
+ half = dim // 2
97
+ freqs = torch.exp(
98
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
99
+ ).to(device=t.device)
100
+ args = t[:, None].float() * freqs[None]
101
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
102
+ if dim % 2:
103
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
104
+ return embedding
105
+
106
+ def forward(self, t):
107
+ t_freq = self.timestep_embedding(t, self.frequency_embedding_size)
108
+ t_emb = self.mlp(t_freq)
109
+ return t_emb
110
+
111
+
112
+ class ResBlock(nn.Module):
113
+ """
114
+ A residual block with AdaLN for timestep and optional concatenation for condition.
115
+ """
116
+ def __init__(
117
+ self,
118
+ channels,
119
+ cond_method="adaln",
120
+ ):
121
+ super().__init__()
122
+ self.channels = channels
123
+ self.cond_method = cond_method
124
+
125
+ self.in_ln = nn.LayerNorm(channels, eps=1e-6)
126
+ self.adaLN_modulation = nn.Sequential(
127
+ nn.SiLU(),
128
+ nn.Linear(channels, 3 * channels, bias=True)
129
+ )
130
+
131
+ # Input dimension depends on conditioning method
132
+ mlp_in_dim = channels * 2 if cond_method == "concat" else channels
133
+ self.mlp = nn.Sequential(
134
+ nn.Linear(mlp_in_dim, channels, bias=True),
135
+ nn.SiLU(),
136
+ nn.Linear(channels, channels, bias=True),
137
+ )
138
+
139
+ def forward(self, x, t, c=None):
140
+ # Apply timestep embedding via AdaLN
141
+ shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation(t).chunk(3, dim=-1)
142
+ h = modulate(self.in_ln(x), shift_mlp, scale_mlp)
143
+
144
+ # Concatenate condition if using concat method
145
+ if self.cond_method == "concat" and c is not None:
146
+ h = torch.cat([h, c], dim=-1)
147
+
148
+ h = self.mlp(h)
149
+ x = x + gate_mlp * h
150
+ return x
151
+
152
+
153
+ class FinalLayer(nn.Module):
154
+ """
155
+ Final layer with AdaLN for timestep and optional concatenation for condition.
156
+ """
157
+ def __init__(self, model_channels, out_channels, cond_method="adaln"):
158
+ super().__init__()
159
+ self.norm_final = nn.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6)
160
+ self.cond_method = cond_method
161
+
162
+ self.adaLN_modulation = nn.Sequential(
163
+ nn.SiLU(),
164
+ nn.Linear(model_channels, 2 * model_channels, bias=True)
165
+ )
166
+
167
+ # Output dimension depends on conditioning method
168
+ linear_in_dim = model_channels * 2 if cond_method == "concat" else model_channels
169
+ self.linear = nn.Linear(linear_in_dim, out_channels, bias=True)
170
+
171
+ def forward(self, x, t, c=None):
172
+ # Apply timestep embedding via AdaLN
173
+ shift, scale = self.adaLN_modulation(t).chunk(2, dim=-1)
174
+ x = modulate(self.norm_final(x), shift, scale)
175
+
176
+ # Concatenate condition if using concat method
177
+ if self.cond_method == "concat" and c is not None:
178
+ x = torch.cat([x, c], dim=-1)
179
+
180
+ return self.linear(x)
181
+
182
+
183
+ class SimpleMLPAdaLN(nn.Module):
184
+ """
185
+ MLP for Diffusion Loss with AdaLN for timestep and optional concatenation for condition.
186
+ """
187
+ def __init__(
188
+ self,
189
+ in_channels,
190
+ model_channels,
191
+ out_channels,
192
+ z_channels,
193
+ num_res_blocks,
194
+ cond_method="adaln"
195
+ ):
196
+ super().__init__()
197
+ self.in_channels = in_channels
198
+ self.model_channels = model_channels
199
+ self.out_channels = out_channels
200
+ self.cond_method = cond_method
201
+
202
+ self.time_embed = TimestepEmbedder(model_channels)
203
+ self.cond_embed = nn.Linear(z_channels, model_channels)
204
+ self.input_proj = nn.Linear(in_channels, model_channels)
205
+
206
+ # Create residual blocks
207
+ res_blocks = [ResBlock(model_channels, cond_method) for _ in range(num_res_blocks)]
208
+ self.res_blocks = nn.ModuleList(res_blocks)
209
+
210
+ self.final_layer = FinalLayer(model_channels, out_channels, cond_method=cond_method)
211
+ self.initialize_weights()
212
+
213
+ def initialize_weights(self):
214
+ # Basic initialization for all linear layers
215
+ def _basic_init(module):
216
+ if isinstance(module, nn.Linear):
217
+ torch.nn.init.xavier_uniform_(module.weight)
218
+ if module.bias is not None:
219
+ nn.init.constant_(module.bias, 0)
220
+ self.apply(_basic_init)
221
+
222
+ # Initialize timestep embedding MLP
223
+ nn.init.normal_(self.time_embed.mlp[0].weight, std=0.02)
224
+ nn.init.normal_(self.time_embed.mlp[2].weight, std=0.02)
225
+
226
+ # Zero-out adaLN modulation layers (always used for timestep)
227
+ for i, block in enumerate(self.res_blocks):
228
+ nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
229
+ nn.init.constant_(block.adaLN_modulation[-1].bias, 0)
230
+
231
+ # Zero-out output layers
232
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0)
233
+ nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0)
234
+ nn.init.constant_(self.final_layer.linear.weight, 0)
235
+ nn.init.constant_(self.final_layer.linear.bias, 0)
236
+
237
+ def forward(self, x, t, c):
238
+ """
239
+ Apply the model to an input batch.
240
+ :param x: an [N x C] Tensor of inputs.
241
+ :param t: a 1-D batch of timesteps.
242
+ :param c: conditioning from AR transformer.
243
+ :return: an [N x C] Tensor of outputs.
244
+ """
245
+ x = self.input_proj(x)
246
+ t_emb = self.time_embed(t)
247
+ c_emb = self.cond_embed(c)
248
+
249
+ # Prepare conditioning based on method
250
+ if self.cond_method == "adaln":
251
+ t_combined, c_for_concat = t_emb + c_emb, None
252
+ else: # concat
253
+ t_combined, c_for_concat = t_emb, c_emb
254
+
255
+ for block in self.res_blocks:
256
+ x = block(x, t_combined, c_for_concat)
257
+ return self.final_layer(x, t_combined, c_for_concat)
258
+
259
+ def forward_with_cfg(self, x, t, c, cfg_scale):
260
+ half = x[: len(x) // 2]
261
+ combined = torch.cat([half, half], dim=0)
262
+ model_out = self.forward(combined, t, c)
263
+ eps, rest = model_out[:, :self.in_channels], model_out[:, self.in_channels:]
264
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
265
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
266
+ eps = torch.cat([half_eps, half_eps], dim=0)
267
+ return torch.cat([eps, rest], dim=1)
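A minimal sketch (not part of the commit) of DiffLoss in isolation, with illustrative constructor values: train on target token latents conditioned on per-token features z, then draw samples without CFG:

import torch
from semanticist.stage2.diffloss import DiffLoss

diffloss = DiffLoss(target_channels=16, z_channels=768, depth=3, width=1024,
                    num_sampling_steps='100', use_si=True)
target = torch.randn(32, 16)  # flattened slot latents
z = torch.randn(32, 768)      # conditioning from the AR transformer
loss = diffloss(target, z)

latents = diffloss.sample(torch.randn(8, 768), temperature=1.0, cfg=1.0)  # [8, 16]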
semanticist/stage2/generate.py ADDED
@@ -0,0 +1,88 @@
1
+ # Modified from:
2
+ # gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/generate.py
3
+ # DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
4
+ import torch
5
+
6
+ def prefill(model, cond_idx: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, temperature: float = 1.0):
7
+ tokens = model(None, cond_idx, input_pos, cfg=cfg_scale, temperature=temperature)
8
+ return tokens.unsqueeze(1)
9
+
10
+
11
+ def decode_one_token(model, x: torch.Tensor, input_pos: torch.Tensor, cfg_scale: float, temperature: float = 1.0):
12
+ assert input_pos.shape[-1] == 1
13
+ if cfg_scale > 1.0:
14
+ x = torch.cat([x, x])
15
+ tokens = model(x, cond_idx=None, input_pos=input_pos, cfg=cfg_scale, temperature=temperature)
16
+ return tokens
17
+
18
+
19
+ def decode_n_tokens(
20
+ model, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int,
21
+ cfg_scale: float, cfg_schedule="constant", temperature: float = 1.0):
22
+ new_tokens = []
23
+ for i in range(num_new_tokens):
24
+ cfg_iter = get_cfg(cfg_scale, i + 1, num_new_tokens + 1, cfg_schedule)
25
+ next_token = decode_one_token(model, cur_token, input_pos, cfg_iter, temperature=temperature).unsqueeze(1)
26
+ input_pos += 1
27
+ new_tokens.append(next_token.clone())
28
+ cur_token = next_token
29
+
30
+ return new_tokens
31
+
32
+
33
+ @torch.no_grad()
34
+ def generate(model, cond, max_new_tokens, emb_masks=None, cfg_scale=1.0, cfg_schedule="constant", temperature: float = 1.0):
35
+ if cfg_scale > 1.0:
36
+ cond_null = torch.ones_like(cond) * model.num_classes
37
+ cond_combined = torch.cat([cond, cond_null])
38
+ else:
39
+ cond_combined = cond
40
+ T = model.cls_token_num
41
+
42
+ T_new = T + max_new_tokens
43
+ max_seq_length = T_new
44
+ max_batch_size = cond.shape[0]
45
+
46
+ device = cond.device
47
+ dtype = model.z_proj.weight.dtype
48
+ if torch.is_autocast_enabled():
49
+ dtype = torch.get_autocast_dtype(device_type=device.type)
50
+ with torch.device(device):
51
+ max_batch_size_cfg = max_batch_size * 2 if cfg_scale > 1.0 else max_batch_size
52
+ model.setup_caches(max_batch_size=max_batch_size_cfg, max_seq_length=max_seq_length, dtype=dtype)
53
+
54
+ if emb_masks is not None:
55
+ assert emb_masks.shape[0] == max_batch_size
56
+ assert emb_masks.shape[-1] == T
57
+ if cfg_scale > 1.0:
58
+ model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * torch.cat([emb_masks, emb_masks]).unsqueeze(1)
59
+ else:
60
+ model.causal_mask[:, :, :T] = model.causal_mask[:, :, :T] * emb_masks.unsqueeze(1)
61
+
62
+ eye_matrix = torch.eye(model.causal_mask.size(1), model.causal_mask.size(2), device=device)
63
+ model.causal_mask[:] = model.causal_mask * (1 - eye_matrix) + eye_matrix
64
+
65
+ # create an empty tensor of the expected final shape and fill in the current tokens
66
+ seq = torch.empty((max_batch_size, T_new, model.slot_dim), dtype=dtype, device=device)
67
+
68
+ input_pos = torch.arange(0, T, device=device)
69
+ cfg_iter = get_cfg(cfg_scale, 0, max_new_tokens, cfg_schedule)
70
+ next_token = prefill(model, cond_combined, input_pos, cfg_iter, temperature=temperature)
71
+ seq[:, T:T+1] = next_token
72
+
73
+ if max_new_tokens > 1:
74
+ input_pos = torch.tensor([T], device=device, dtype=torch.int)
75
+ generated_tokens = decode_n_tokens(model, next_token, input_pos, max_new_tokens - 1, cfg_scale, cfg_schedule=cfg_schedule, temperature=temperature)
76
+ seq[:, T+1:] = torch.cat(generated_tokens, dim=1)
77
+
78
+ model.reset_caches()
79
+ return seq[:, T:]
80
+
81
+
82
+ def get_cfg(cfg, cur_step, total_step, cfg_schedule="constant"):
83
+ if cfg_schedule == "linear":
84
+ return 1 + (cfg - 1) * (cur_step + 1) / total_step
85
+ elif cfg_schedule == "constant":
86
+ return cfg
87
+ else:
88
+ raise NotImplementedError
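A quick check (not part of the commit) of the schedule helper: with cfg_schedule="linear" the guidance ramps from near 1 at the first step up to cfg at the last step:

from semanticist.stage2.generate import get_cfg
print([get_cfg(6.0, s, 4, cfg_schedule="linear") for s in range(4)])
# [2.25, 3.5, 4.75, 6.0]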
semanticist/stage2/gpt.py ADDED
@@ -0,0 +1,431 @@
1
+ # Modified from:
2
+ # VQGAN: https://github.com/CompVis/taming-transformers/blob/master/taming/modules/transformer/mingpt.py
3
+ # DiT: https://github.com/facebookresearch/DiT/blob/main/models.py
4
+ # nanoGPT: https://github.com/karpathy/nanoGPT/blob/master/model.py
5
+ # llama: https://github.com/facebookresearch/llama/blob/main/llama/model.py
6
+ # gpt-fast: https://github.com/pytorch-labs/gpt-fast/blob/main/model.py
7
+ # PixArt: https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
8
+ from typing import Optional, List, Union
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from torch.nn import functional as F
13
+
14
+ from semanticist.stage1.vision_transformer import DropPath
15
+ from semanticist.stage2.diffloss import DiffLoss
16
+
17
+ def find_multiple(n: int, k: int):
18
+ if n % k == 0:
19
+ return n
20
+ return n + k - (n % k)
21
+
22
+
23
+
24
+ #################################################################################
25
+ # Embedding Layers for Class Labels #
26
+ #################################################################################
27
+ class LabelEmbedder(nn.Module):
28
+ """
29
+ Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
30
+ """
31
+ def __init__(self, num_classes, hidden_size, dropout_prob):
32
+ super().__init__()
33
+ use_cfg_embedding = dropout_prob > 0
34
+ self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
35
+ self.num_classes = num_classes
36
+ self.dropout_prob = dropout_prob
37
+
38
+ def token_drop(self, labels, force_drop_ids=None):
39
+ """
40
+ Drops labels to enable classifier-free guidance.
41
+ """
42
+ if force_drop_ids is None:
43
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
44
+ else:
45
+ drop_ids = force_drop_ids == 1
46
+ labels = torch.where(drop_ids, self.num_classes, labels)
47
+ return labels
48
+
49
+ def forward(self, labels, train, force_drop_ids=None):
50
+ use_dropout = self.dropout_prob > 0
51
+ if (train and use_dropout) or (force_drop_ids is not None):
52
+ labels = self.token_drop(labels, force_drop_ids)
53
+ embeddings = self.embedding_table(labels).unsqueeze(1)
54
+ return embeddings
55
+
56
+
57
+ class MLP(nn.Module):
58
+ def __init__(self, in_features, hidden_features, out_features):
59
+ super().__init__()
60
+ out_features = out_features or in_features
61
+ hidden_features = hidden_features or in_features
62
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=False)
63
+ self.act = nn.GELU(approximate='tanh')
64
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=False)
65
+
66
+ def forward(self, x):
67
+ x = self.fc1(x)
68
+ x = self.act(x)
69
+ x = self.fc2(x)
70
+ return x
71
+
72
+
73
+ #################################################################################
74
+ # GPT Model #
75
+ #################################################################################
76
+ class RMSNorm(torch.nn.Module):
77
+ def __init__(self, dim: int, eps: float = 1e-5):
78
+ super().__init__()
79
+ self.eps = eps
80
+ self.weight = nn.Parameter(torch.ones(dim))
81
+
82
+ def _norm(self, x):
83
+ return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)
84
+
85
+ def forward(self, x):
86
+ output = self._norm(x.float()).type_as(x)
87
+ return output * self.weight
88
+
89
+
90
+ class FeedForward(nn.Module):
91
+ def __init__(
92
+ self,
93
+ dim: int,
94
+ multiple_of: int = 256,
95
+ ffn_dropout_p: float = 0.0,
96
+ ):
97
+ super().__init__()
98
+ hidden_dim = 4 * dim
99
+ hidden_dim = int(2 * hidden_dim / 3)
100
+ hidden_dim = find_multiple(hidden_dim, multiple_of)
101
+
102
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
103
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
104
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
105
+ self.ffn_dropout = nn.Dropout(ffn_dropout_p)
106
+
107
+ def forward(self, x):
108
+ return self.ffn_dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
109
+
110
+
111
+ class KVCache(nn.Module):
112
+ def __init__(self, max_batch_size, max_seq_length, n_head, head_dim, dtype):
113
+ super().__init__()
114
+ cache_shape = (max_batch_size, n_head, max_seq_length, head_dim)
115
+ self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
116
+ self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))
117
+
118
+ def update(self, input_pos, k_val, v_val):
119
+ # input_pos: [S], k_val: [B, H, S, D]
120
+ assert input_pos.shape[0] == k_val.shape[2]
121
+ k_out = self.k_cache
122
+ v_out = self.v_cache
123
+ k_out[:, :, input_pos] = k_val
124
+ v_out[:, :, input_pos] = v_val
125
+
126
+ return k_out, v_out
127
+
128
+
129
+ class Attention(nn.Module):
130
+ def __init__(
131
+ self,
132
+ dim: int,
133
+ n_head: int,
134
+ attn_dropout_p: float = 0.0,
135
+ resid_dropout_p: float = 0.1,
136
+ ):
137
+ super().__init__()
138
+ assert dim % n_head == 0
139
+ self.dim = dim
140
+ self.head_dim = dim // n_head
141
+ self.n_head = n_head
142
+
143
+ # key, query, value projections for all heads, but in a batch
144
+ self.wqkv = nn.Linear(dim, dim * 3, bias=False)
145
+ self.wo = nn.Linear(dim, dim, bias=False)
146
+ self.kv_cache = None
147
+
148
+ # regularization
149
+ self.attn_dropout_p = attn_dropout_p
150
+ self.resid_dropout = nn.Dropout(resid_dropout_p)
151
+
152
+ def forward(
153
+ self, x: torch.Tensor,
154
+ input_pos: Optional[torch.Tensor] = None,
155
+ mask: Optional[torch.Tensor] = None
156
+ ):
157
+ bsz, seqlen, _ = x.shape
158
+ xq, xk, xv = self.wqkv(x).split([self.dim, self.dim, self.dim], dim=-1)
159
+
160
+ xq = xq.view(bsz, seqlen, self.n_head, self.head_dim)
161
+ xk = xk.view(bsz, seqlen, self.n_head, self.head_dim)
162
+ xv = xv.view(bsz, seqlen, self.n_head, self.head_dim)
163
+
164
+ xq, xk, xv = map(lambda x: x.transpose(1, 2), (xq, xk, xv))
165
+
166
+ if self.kv_cache is not None:
167
+ keys, values = self.kv_cache.update(input_pos, xk, xv)
168
+ else:
169
+ keys, values = xk, xv
170
+
171
+ output = F.scaled_dot_product_attention(
172
+ xq, keys, values,
173
+ attn_mask=mask,
174
+ is_causal=True if mask is None else False, # is_causal=False is for KV cache
175
+ dropout_p=self.attn_dropout_p if self.training else 0)
176
+
177
+ output = output.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
178
+
179
+ output = self.resid_dropout(self.wo(output))
180
+ return output
181
+
182
+
+class TransformerBlock(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        n_head: int,
+        multiple_of: int = 256,
+        norm_eps: float = 1e-5,
+        attn_dropout_p: float = 0.0,
+        ffn_dropout_p: float = 0.1,
+        resid_dropout_p: float = 0.1,
+        drop_path: float = 0.0,
+    ):
+        super().__init__()
+        self.attention = Attention(
+            dim=dim,
+            n_head=n_head,
+            attn_dropout_p=attn_dropout_p,
+            resid_dropout_p=resid_dropout_p,
+        )
+        self.feed_forward = FeedForward(
+            dim=dim,
+            multiple_of=multiple_of,
+            ffn_dropout_p=ffn_dropout_p,
+        )
+        self.attention_norm = RMSNorm(dim, eps=norm_eps)
+        self.ffn_norm = RMSNorm(dim, eps=norm_eps)
+        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+
+    def forward(self, x: torch.Tensor, start_pos: int, mask: Optional[torch.Tensor] = None):
+        # Pre-norm residual layout: x + Attn(RMSNorm(x)), then h + FFN(RMSNorm(h)).
+        h = x + self.drop_path(self.attention(self.attention_norm(x), start_pos, mask))
+        out = h + self.drop_path(self.feed_forward(self.ffn_norm(h)))
+        return out
+
+
+class Transformer(nn.Module):
+    def __init__(
+        self,
+        dim: int = 4096,
+        n_layer: int = 32,
+        n_head: int = 32,
+        attn_dropout_p: float = 0.0,
+        resid_dropout_p: float = 0.1,
+        ffn_dropout_p: float = 0.1,
+        drop_path_rate: float = 0.0,
+        num_classes: Union[int, List[int]] = 1000,
+        class_dropout_prob: float = 0.1,
+
+        cls_token_num: int = 1,
+        num_slots: int = 16,
+        slot_dim: int = 256,
+
+        diffloss_d: int = 3,
+        diffloss_w: int = 1024,
+        num_sampling_steps: str = '100',
+        diffusion_batch_mul: int = 4,
+        predict_xstart: bool = False,
+        use_si: bool = False,
+        cond_method: str = "adaln",
+        **kwargs,
+    ):
+        super().__init__()
+
+        # Store configuration
+        self.dim = dim
+        self.n_layer = n_layer
+        self.n_head = n_head
+        self.num_slots = num_slots
+        self.slot_dim = slot_dim
+        self.num_classes = num_classes
+        self.cls_token_num = cls_token_num
+
+        # Initialize embeddings
+        self.cls_embedding = LabelEmbedder(num_classes, dim, class_dropout_prob)
+        self.z_proj = nn.Linear(slot_dim, dim, bias=True)
+        self.z_proj_ln = RMSNorm(dim)
+        self.pos_embed_learned = nn.Parameter(torch.zeros(1, num_slots + cls_token_num, dim))
+
+        # transformer blocks
+        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, n_layer)]
+        self.layers = torch.nn.ModuleList()
+        for layer_id in range(n_layer):
+            self.layers.append(TransformerBlock(
+                dim=dim,
+                n_head=n_head,
+                ffn_dropout_p=ffn_dropout_p,
+                attn_dropout_p=attn_dropout_p,
+                resid_dropout_p=resid_dropout_p,
+                drop_path=dpr[layer_id],
+            ))
+
+        # output layer
+        self.norm = RMSNorm(dim)
+
+        self.diffusion_pos_embed_learned = nn.Parameter(torch.zeros(1, num_slots, dim))
+
+        # KVCache
+        self.max_batch_size = -1
+        self.max_seq_length = -1
+
+        self.initialize_weights()
+
+        # Diffusion Loss
+        self.diffloss = DiffLoss(
+            target_channels=slot_dim,
+            z_channels=dim,
+            width=diffloss_w,
+            depth=diffloss_d,
+            num_sampling_steps=num_sampling_steps,
+            predict_xstart=predict_xstart,
+            use_si=use_si,
+            cond_method=cond_method,
+        )
+        self.diffusion_batch_mul = diffusion_batch_mul
+
+    def initialize_weights(self):
+        nn.init.normal_(self.pos_embed_learned, std=0.02)
+        nn.init.normal_(self.diffusion_pos_embed_learned, std=0.02)
+        # Initialize nn.Linear and nn.Embedding
+        self.apply(self._init_weights)
+
+    def _init_weights(self, module):
+        if isinstance(module, nn.Linear):
+            module.weight.data.normal_(std=0.02)
+            if module.bias is not None:
+                module.bias.data.zero_()
+        elif isinstance(module, nn.Embedding):
+            module.weight.data.normal_(std=0.02)
+
+    def setup_caches(self, max_batch_size, max_seq_length, dtype):
+        # if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
+        #     return
+        head_dim = self.dim // self.n_head
+        max_seq_length = find_multiple(max_seq_length, 8)
+        self.max_seq_length = max_seq_length
+        self.max_batch_size = max_batch_size
+        for b in self.layers:
+            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.n_head, head_dim, dtype)
+
+        causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
+        self.causal_mask = causal_mask.unsqueeze(0).repeat(self.max_batch_size, 1, 1)
+
+    def reset_caches(self):
+        self.max_seq_length = -1
+        self.max_batch_size = -1
+        for b in self.layers:
+            b.attention.kv_cache = None
+
+    def forward_loss(self, z, target):
+        # Flatten (batch, seq) and repeat diffusion_batch_mul times so every slot is
+        # supervised with several diffusion noise samples per optimizer step.
+        bsz, seq_len, _ = target.shape
+        target = target.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
+        z = z.reshape(bsz * seq_len, -1).repeat(self.diffusion_batch_mul, 1)
+        loss = self.diffloss(z=z, target=target)
+        return loss
+
+    def forward_cfg(self, h, cfg):
+        # Classifier-free guidance: h_uncond + cfg * (h_cond - h_uncond).
+        if cfg > 1.0:
+            h_cond, h_uncond = h.chunk(2, dim=0)
+            h = h_uncond + cfg * (h_cond - h_uncond)
+        return h
+
+    def forward(
+        self,
+        slots: torch.Tensor,
+        cond_idx: torch.Tensor,
+        input_pos: Optional[torch.Tensor] = None,
+        mask: Optional[torch.Tensor] = None,
+        cfg: float = 1.0,
+        temperature: float = 1.0
+    ):
+        if slots is not None and cond_idx is not None:  # training or naive inference
+            cond_embeddings = self.cls_embedding(cond_idx, train=self.training)
+            cond_embeddings = cond_embeddings.expand(-1, self.cls_token_num, -1)
+            token_embeddings = self.z_proj(slots)
+            token_embeddings = torch.cat((cond_embeddings, token_embeddings), dim=1)
+        else:
+            if cond_idx is not None:  # prefill in inference
+                token_embeddings = self.cls_embedding(cond_idx, train=self.training)
+                token_embeddings = token_embeddings.expand(-1, self.cls_token_num, -1)
+            else:  # decode_n_tokens (KV cache) in inference
+                token_embeddings = self.z_proj(slots)
+
+            bs = token_embeddings.shape[0]
+            mask = self.causal_mask[:bs, None, input_pos]
+
+        h = token_embeddings
+        if self.training:
+            h = h + self.pos_embed_learned
+        else:
+            h = h + self.pos_embed_learned[:, input_pos].view(1, -1, self.dim)
+
+        h = self.z_proj_ln(h)  # not sure if this is needed
+
+        # transformer blocks
+        for layer in self.layers:
+            h = layer(h, input_pos, mask)
+
+        h = self.norm(h)
+
+        if self.training:
+            # Shift by one: the hidden state at position i - 1 (class token included)
+            # conditions the diffusion head that predicts slot i.
+            h = h[:, self.cls_token_num - 1 : -1].contiguous()
+            h = h + self.diffusion_pos_embed_learned
+            loss = self.forward_loss(h, slots.detach())
+            return loss
+        else:
+            h = h[:, -1]
+            h = h + self.diffusion_pos_embed_learned[:, input_pos[-1] - self.cls_token_num + 1]
+            next_tokens = self.diffloss.sample(h, temperature=temperature, cfg=cfg)
+            return next_tokens
+
+    def get_fsdp_wrap_module_list(self) -> List[nn.Module]:
+        return list(self.layers)
+
+
+#################################################################################
+#                                  GPT Configs                                  #
+#################################################################################
+### text-conditional
+def GPT_7B(**kwargs):
+    return Transformer(n_layer=32, n_head=32, dim=4096, **kwargs)  # 6.6B
+
+def GPT_3B(**kwargs):
+    return Transformer(n_layer=24, n_head=32, dim=3200, **kwargs)  # 3.1B
+
+def GPT_1B(**kwargs):
+    return Transformer(n_layer=22, n_head=32, dim=2048, **kwargs)  # 1.2B
+
+### class-conditional
+def GPT_XXXL(**kwargs):
+    return Transformer(n_layer=48, n_head=40, dim=2560, **kwargs)  # 3.9B
+
+def GPT_XXL(**kwargs):
+    return Transformer(n_layer=48, n_head=24, dim=1536, **kwargs)  # 1.4B
+
+def GPT_XL(**kwargs):
+    return Transformer(n_layer=36, n_head=20, dim=1280, **kwargs)  # 775M
+
+def GPT_L(**kwargs):
+    return Transformer(n_layer=24, n_head=16, dim=1024, **kwargs)  # 343M
+
+def GPT_B(**kwargs):
+    return Transformer(n_layer=12, n_head=12, dim=768, **kwargs)  # 111M
+
+
+GPT_models = {
+    'GPT-B': GPT_B, 'GPT-L': GPT_L, 'GPT-XL': GPT_XL, 'GPT-XXL': GPT_XXL, 'GPT-XXXL': GPT_XXXL,
+    'GPT-1B': GPT_1B, 'GPT-3B': GPT_3B, 'GPT-7B': GPT_7B,
+}
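
A minimal usage sketch (illustrative, not part of the commit): build the class-conditional GPT-L variant from the registry and run the training path, which returns the diffusion loss over all slots. The slot shape (32 slots of dimension 16) is an arbitrary example choice, and it assumes the repo's LabelEmbedder yields a [B, 1, dim] embedding so the expand to cls_token_num works.

import torch
from semanticist.stage2.gpt import GPT_models

model = GPT_models['GPT-L'](num_slots=32, slot_dim=16, num_classes=1000, cls_token_num=1)
model.train()
slots = torch.randn(2, 32, 16)          # [batch, num_slots, slot_dim]
labels = torch.randint(0, 1000, (2,))   # class-conditional labels, one per sample
loss = model(slots, labels)             # training branch -> scalar diffusion loss
print(loss.item())

For KV-cache inference, setup_caches must be called first so that self.causal_mask and the per-layer caches exist.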
semanticist/utils/datasets.py ADDED
@@ -0,0 +1,72 @@
+import torch
+import torchvision
+import numpy as np
+import os.path as osp
+from PIL import Image
+import torchvision.transforms as TF
+
+def pair(t):
+    return t if isinstance(t, tuple) else (t, t)
+
+def center_crop_arr(pil_image, image_size):
+    """
+    Center cropping implementation from ADM.
+    https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
+    """
+    while min(*pil_image.size) >= 2 * image_size:
+        pil_image = pil_image.resize(
+            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
+        )
+
+    scale = image_size / min(*pil_image.size)
+    pil_image = pil_image.resize(
+        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
+    )
+
+    arr = np.array(pil_image)
+    crop_y = (arr.shape[0] - image_size) // 2
+    crop_x = (arr.shape[1] - image_size) // 2
+    return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
+
+def vae_transforms(split, aug='randcrop', img_size=256):
+    t = []
+    if split == 'train':
+        if aug == 'randcrop':
+            t.append(TF.Resize(img_size, interpolation=TF.InterpolationMode.BICUBIC, antialias=True))
+            t.append(TF.RandomCrop(img_size))
+        elif aug == 'centercrop':
+            t.append(TF.Lambda(lambda x: center_crop_arr(x, img_size)))
+        else:
+            raise ValueError(f"Invalid augmentation: {aug}")
+        t.append(TF.RandomHorizontalFlip(p=0.5))
+    else:
+        t.append(TF.Lambda(lambda x: center_crop_arr(x, img_size)))
+
+    t.append(TF.ToTensor())
+
+    return TF.Compose(t)
+
+
+def cached_transforms(aug='tencrop', img_size=256, crop_ranges=[1.05, 1.10]):
+    t = []
+    if 'centercrop' in aug:
+        t.append(TF.Lambda(lambda x: center_crop_arr(x, img_size)))
+        t.append(TF.Lambda(lambda x: torch.stack([TF.ToTensor()(x), TF.ToTensor()(TF.functional.hflip(x))])))
+    elif 'tencrop' in aug:
+        crop_sizes = [int(img_size * crop_range) for crop_range in crop_ranges]
+        t.append(TF.Lambda(lambda x: [center_crop_arr(x, crop_size) for crop_size in crop_sizes]))
+        t.append(TF.Lambda(lambda crops: [crop for crop_tuple in [TF.TenCrop(img_size)(crop) for crop in crops] for crop in crop_tuple]))
+        t.append(TF.Lambda(lambda crops: torch.stack([TF.ToTensor()(crop) for crop in crops])))
+    else:
+        raise ValueError(f"Invalid augmentation: {aug}")
+
+    return TF.Compose(t)
+
+class ImageNet(torchvision.datasets.ImageFolder):
+    def __init__(self, root, split='train', aug='randcrop', img_size=256):
+        super().__init__(osp.join(root, split))
+        if 'cache' not in aug:
+            self.transform = vae_transforms(split, aug=aug, img_size=img_size)
+        else:
+            self.transform = cached_transforms(aug=aug, img_size=img_size)
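
A hedged usage sketch (illustrative, not part of the commit): the dataset wraps torchvision's ImageFolder, so the placeholder root below must contain the usual root/<split>/<class>/<image> layout.

from semanticist.utils.datasets import ImageNet

ds = ImageNet('/path/to/imagenet', split='train', aug='randcrop', img_size=256)
img, label = ds[0]   # img: [3, 256, 256] float tensor in [0, 1]; label: class index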
semanticist/utils/device_utils.py ADDED
@@ -0,0 +1,18 @@
+import torch
+
+def configure_compute_backend():
+    """Configure PyTorch compute backend settings for CUDA."""
+    if torch.cuda.is_available():
+        torch.backends.cuda.matmul.allow_tf32 = True
+        torch.backends.cudnn.allow_tf32 = True
+        torch.backends.cudnn.benchmark = True
+        torch.backends.cudnn.deterministic = False
+    else:
+        raise ValueError("No CUDA available")
+
+def get_device():
+    """Get the device to use for training."""
+    if torch.cuda.is_available():
+        return torch.device("cuda")
+    else:
+        raise ValueError("No CUDA available")
semanticist/utils/logger.py ADDED
@@ -0,0 +1,170 @@
+from collections import defaultdict, deque
+import datetime
+import time
+import torch
+import torch.distributed as dist
+from semanticist.engine.trainer_utils import is_dist_avail_and_initialized, is_main_process
+from semanticist.utils.device_utils import get_device
+
+def synchronize_processes():
+    if torch.cuda.is_available():
+        torch.cuda.synchronize()
+    else:  # do nothing
+        pass
+
+def empty_cache():
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+    else:  # do nothing
+        pass
+
+class SmoothedValue(object):
+    """Track a series of values and provide access to smoothed values over a
+    window or the global series average.
+    """
+
+    def __init__(self, window_size=20, fmt=None):
+        if fmt is None:
+            fmt = "{median:.4f} ({global_avg:.4f})"
+        self.deque = deque(maxlen=window_size)
+        self.total = 0.0
+        self.count = 0
+        self.fmt = fmt
+
+    def update(self, value, n=1):
+        self.deque.append(value)
+        self.count += n
+        self.total += value * n
+
+    def synchronize_between_processes(self):
+        """
+        Warning: does not synchronize the deque!
+        """
+        if not is_dist_avail_and_initialized():
+            return
+        t = torch.tensor([self.count, self.total], dtype=torch.float32, device=get_device())
+        dist.barrier()
+        dist.all_reduce(t)
+        t = t.tolist()
+        self.count = int(t[0])
+        self.total = t[1]
+
+    @property
+    def median(self):
+        d = torch.tensor(list(self.deque))
+        return d.median().item()
+
+    @property
+    def avg(self):
+        d = torch.tensor(list(self.deque), dtype=torch.float32)
+        return d.mean().item()
+
+    @property
+    def global_avg(self):
+        return self.total / self.count
+
+    @property
+    def max(self):
+        return max(self.deque)
+
+    @property
+    def value(self):
+        return self.deque[-1]
+
+    def __str__(self):
+        return self.fmt.format(
+            median=self.median,
+            avg=self.avg,
+            global_avg=self.global_avg,
+            max=self.max,
+            value=self.value)
+
+
+class MetricLogger(object):
+    def __init__(self, delimiter="\t"):
+        self.meters = defaultdict(SmoothedValue)
+        self.delimiter = delimiter
+
+    def update(self, **kwargs):
+        for k, v in kwargs.items():
+            if v is None:
+                continue
+            if isinstance(v, torch.Tensor):
+                v = v.item()
+            assert isinstance(v, (float, int))
+            self.meters[k].update(v)
+
+    def __getattr__(self, attr):
+        if attr in self.meters:
+            return self.meters[attr]
+        if attr in self.__dict__:
+            return self.__dict__[attr]
+        raise AttributeError("'{}' object has no attribute '{}'".format(
+            type(self).__name__, attr))
+
+    def __str__(self):
+        loss_str = []
+        for name, meter in self.meters.items():
+            loss_str.append(
+                "{}: {}".format(name, str(meter))
+            )
+        return self.delimiter.join(loss_str)
+
+    def synchronize_between_processes(self):
+        for meter in self.meters.values():
+            meter.synchronize_between_processes()
+
+    def add_meter(self, name, meter):
+        self.meters[name] = meter
+
+    def log_every(self, iterable, print_freq, header=None):
+        i = 0
+        if not header:
+            header = ''
+        start_time = time.time()
+        end = time.time()
+        iter_time = SmoothedValue(fmt='{avg:.4f}')
+        data_time = SmoothedValue(fmt='{avg:.4f}')
+        space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
+        log_msg = [
+            header,
+            '[{0' + space_fmt + '}/{1}]',
+            'eta: {eta}',
+            '{meters}',
+            'time: {time}',
+            'data: {data}'
+        ]
+        if torch.cuda.is_available():
+            log_msg.append('mem: {memory:.0f}')
+            log_msg.append("util: {util:.1f}%")
+        log_msg = self.delimiter.join(log_msg)
+        MB = 1024.0 * 1024.0
+        for obj in iterable:
+            data_time.update(time.time() - end)
+            yield obj
+            iter_time.update(time.time() - end)
+            if i % print_freq == 0 or i == len(iterable) - 1:
+                eta_seconds = iter_time.global_avg * (len(iterable) - i)
+                eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+                if torch.cuda.is_available():
+                    if is_main_process():
+                        memory = torch.cuda.max_memory_allocated()
+                        util = torch.cuda.utilization()
+                        print(log_msg.format(
+                            i, len(iterable), eta=eta_string,
+                            meters=str(self),
+                            time=str(iter_time), data=str(data_time),
+                            memory=memory / MB, util=util))
+                else:
+                    if is_main_process():
+                        print(log_msg.format(
+                            i, len(iterable), eta=eta_string,
+                            meters=str(self),
+                            time=str(iter_time), data=str(data_time)))
+            i += 1
+            end = time.time()
+        total_time = time.time() - start_time
+        total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+        if is_main_process():
+            print('{} Total time: {} ({:.4f} s / it)'.format(
+                header, total_time_str, total_time / len(iterable)))
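
A hedged usage sketch for MetricLogger (illustrative only): log_every wraps any sized iterable, yields its items unchanged, and prints smoothed meters every print_freq iterations on the main process.

from semanticist.utils.logger import MetricLogger

metric_logger = MetricLogger(delimiter="  ")
for step in metric_logger.log_every(range(100), print_freq=10, header='Epoch: [0]'):
    loss = 1.0 / (step + 1)           # stand-in for a real training loss
    metric_logger.update(loss=loss)   # tracked as a SmoothedValue named 'loss'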
semanticist/utils/lr_scheduler.py ADDED
@@ -0,0 +1,15 @@
+from timm.scheduler.cosine_lr import CosineLRScheduler
+from timm.scheduler.step_lr import StepLRScheduler
+
+def build_scheduler(optimizer, n_epoch, n_iter_per_epoch, lr_min=0, warmup_steps=0, warmup_lr_init=0, decay_steps=None, cosine_lr=True):
+    if decay_steps is None:
+        decay_steps = n_epoch * n_iter_per_epoch
+
+    if cosine_lr:
+        scheduler = CosineLRScheduler(optimizer, t_initial=decay_steps, lr_min=lr_min, warmup_t=warmup_steps, warmup_lr_init=warmup_lr_init,
+                                      cycle_limit=1, t_in_epochs=False, warmup_prefix=True)
+    else:
+        scheduler = StepLRScheduler(optimizer, decay_t=decay_steps, warmup_t=warmup_steps, warmup_lr_init=warmup_lr_init,
+                                    t_in_epochs=False, warmup_prefix=True)
+
+    return scheduler
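
Because both schedulers are constructed with t_in_epochs=False, they expect per-iteration stepping through timm's step_update. A hedged sketch with a dummy optimizer (the step counts are arbitrary example values):

import torch
from semanticist.utils.lr_scheduler import build_scheduler

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.AdamW([param], lr=1e-4)
scheduler = build_scheduler(optimizer, n_epoch=400, n_iter_per_epoch=1000,
                            warmup_steps=10000, cosine_lr=True)
for global_step in range(5):
    scheduler.step_update(global_step)   # step once per iteration, not per epoch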
semanticist/utils/transform.py ADDED
@@ -0,0 +1,35 @@
+import PIL
+import torchvision.transforms as T
+
+def pair(t):
+    return t if isinstance(t, tuple) else (t, t)
+
+def stage1_transform(img_size=256, is_train=True, scale=0.8):
+    resize = pair(int(img_size / scale))
+    t = []
+    t.append(T.Resize(resize, interpolation=PIL.Image.BICUBIC))
+    if is_train:
+        t.append(T.RandomCrop(img_size))
+        t.append(T.RandomHorizontalFlip(p=0.5))
+    else:
+        t.append(T.CenterCrop(img_size))
+
+    t.append(T.ToTensor())
+    t.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
+
+    return T.Compose(t)
+
+def stage2_transform(img_size=256, is_train=True, scale=0.8):
+    resize = pair(int(img_size / scale))
+    t = []
+    t.append(T.Resize(resize, interpolation=PIL.Image.BICUBIC))
+    if is_train:
+        t.append(T.RandomCrop(img_size))
+    else:
+        t.append(T.CenterCrop(img_size))
+
+    t.append(T.ToTensor())
+    t.append(T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
+
+    return T.Compose(t)
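
A hedged sketch of the stage-1 pipeline on a dummy image (illustrative only): the Normalize((0.5,), (0.5,)) pair maps pixel values from [0, 1] to roughly [-1, 1].

from PIL import Image
from semanticist.utils.transform import stage1_transform

img = Image.new('RGB', (512, 384), color=(128, 128, 128))
x = stage1_transform(img_size=256, is_train=False)(img)
print(x.shape)   # torch.Size([3, 256, 256]); values near 0.0 for mid-gray input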