bmarci committed
Commit b81a188 · 1 parent: 996db64

Files changed (3):
  1. app.py +31 -30
  2. vae/config.json +1 -2
  3. vae/nextstep_ae.py +39 -72
app.py CHANGED
@@ -3,9 +3,8 @@ import numpy as np
 import random
 import spaces
 from PIL import Image
-
-# import spaces #[uncomment to use ZeroGPU]
 import torch
+from torch.amp import autocast
 
 from transformers import AutoTokenizer, AutoModel
 from models.gen_pipeline import NextStepPipeline
@@ -15,8 +14,15 @@ HF_HUB = "stepfun-ai/NextStep-1-Large"
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 tokenizer = AutoTokenizer.from_pretrained(HF_HUB, local_files_only=False, trust_remote_code=True)
-model = AutoModel.from_pretrained(HF_HUB, local_files_only=False, trust_remote_code=True)
-pipeline = NextStepPipeline(tokenizer=tokenizer, model=model).to(device=device)
+
+model = AutoModel.from_pretrained(
+    HF_HUB,
+    local_files_only=False,
+    trust_remote_code=True,
+    torch_dtype=torch.bfloat16,
+).to(device)
+
+pipeline = NextStepPipeline(tokenizer=tokenizer, model=model).to(device=device, dtype=torch.bfloat16)
 
 MAX_SEED = np.iinfo(np.int16).max
 MAX_IMAGE_SIZE = 512
@@ -30,8 +36,6 @@ def infer(
     seed=0,
     width=512,
     height=512,
-    #text_cfg=7.5,
-    #img_cfg=1.0,
     num_inference_steps=28,
     positive_prompt=DEFAULT_POSITIVE_PROMPT,
     negative_prompt=DEFAULT_NEGATIVE_PROMPT,
@@ -40,21 +44,23 @@ def infer(
     if prompt in [None, ""]:
         gr.Warning("⚠️ Please enter a prompt!")
         return None
-    image = pipeline.generate_image(
-        editing_caption,
-        hw=(height, width),
-        num_images_per_caption=1,
-        positive_prompt=positive_prompt,
-        negative_prompt=negative_prompt,
-        cfg=7.5,
-        cfg_img=1.0,
-        cfg_schedule="constant",
-        use_norm=False,
-        num_sampling_steps=num_inference_steps,
-        timesteps_shift=1.0,
-        seed=seed,
-        progress=True,
-    )
+
+    with autocast(device_type=("cuda" if device == "cuda" else "cpu"), dtype=torch.bfloat16):
+        image = pipeline.generate_image(
+            prompt,
+            hw=(height, width),
+            num_images_per_caption=1,
+            positive_prompt=positive_prompt,
+            negative_prompt=negative_prompt,
+            cfg=7.5,
+            cfg_img=1.0,
+            cfg_schedule="constant",
+            use_norm=False,
+            num_sampling_steps=num_inference_steps,
+            timesteps_shift=1.0,
+            seed=seed,
+            progress=True,
+        )
 
     return image[0]
 
@@ -114,7 +120,7 @@ with gr.Blocks(css=css) as demo:
                 step=1,
                 value=28,
             )
-
+
             with gr.Row():
                 width = gr.Slider(
                     label="Width",
@@ -132,9 +138,7 @@ with gr.Blocks(css=css) as demo:
                )
 
        with gr.Row():
-            result_1 = gr.Image(label="Result 1", show_label=False, container=True, height=400, visible=False)
-
-            #gr.Examples(examples=examples, inputs=[prompt, ref])
+            result_1 = gr.Image(label="Result 1", show_label=False, container=True, height=MAX_IMAGE_SIZE, visible=False)
 
        def show_result():
            return gr.update(visible=True)
@@ -147,15 +151,13 @@ with gr.Blocks(css=css) as demo:
                seed,
                width,
                height,
-                #text_cfg,
-                #img_cfg,
                num_inference_steps,
                positive_prompt,
                negative_prompt,
            ],
            outputs=[result_1],
        )
-
+
        cancel_button.click(
            fn=None,
            inputs=None,
@@ -169,6 +171,5 @@ with gr.Blocks(css=css) as demo:
            outputs=[result_1],
        )
 
-
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
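Note: the bf16 change above follows the standard torch.amp pattern — load the weights in torch.bfloat16, then run generation under autocast so activations match the weight dtype. A minimal, self-contained sketch of that pattern (the Linear module is an illustrative stand-in, not part of the app):

    import torch
    from torch.amp import autocast

    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Stand-in for the NextStep model: weights already in bfloat16.
    net = torch.nn.Linear(16, 16).to(device=device, dtype=torch.bfloat16)

    x = torch.randn(4, 16, device=device)  # fp32 input
    with autocast(device_type=device, dtype=torch.bfloat16):
        y = net(x)  # autocast casts the matmul inputs to bfloat16

    print(y.dtype)  # torch.bfloat16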
vae/config.json CHANGED
@@ -9,7 +9,6 @@
     "shift_factor": 0,
     "scaling_factor": 1,
     "deterministic": true,
-    "norm_fn": "layer_norm",
-    "norm_level": "channel",
+    "encoder_norm": true,
     "psz": 1
 }
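Note: the two removed norm keys collapse into a single boolean that feeds AutoEncoderParams.encoder_norm (see the dataclass change in vae/nextstep_ae.py below). A sketch of how the updated config maps onto the dataclass, mirroring the "filter out kwargs" step in from_pretrained (the import path is an assumption from the repo layout):

    import json
    from dataclasses import fields

    from vae.nextstep_ae import AutoEncoderParams  # assumed import path

    with open("vae/config.json") as f:
        cfg = json.load(f)

    # Keep only keys the dataclass declares, then build the params.
    valid = {f.name for f in fields(AutoEncoderParams)}
    params = AutoEncoderParams(**{k: v for k, v in cfg.items() if k in valid})
    assert params.encoder_norm is True and params.psz == 1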
vae/nextstep_ae.py CHANGED
@@ -2,7 +2,6 @@ import os
 import json
 import inspect
 from dataclasses import dataclass, field, asdict
-from typing import Literal
 from loguru import logger
 from omegaconf import OmegaConf
 from tabulate import tabulate
@@ -10,6 +9,7 @@ from einops import rearrange
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from torch import Tensor
 from torch.utils.checkpoint import checkpoint
 
@@ -17,7 +17,7 @@ from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
 from diffusers.models.modeling_outputs import AutoencoderKLOutput
 
 from utils.misc import LargeInt
-from utils.model_utils import identity, rms_norm, layer_norm, randn_tensor, expand_t
+from utils.model_utils import randn_tensor
 from utils.compile_utils import smart_compile
 
 
@@ -33,8 +33,7 @@ class AutoEncoderParams:
     scaling_factor: float = 0.3611
     shift_factor: float = 0.1159
     deterministic: bool = False
-    norm_fn: Literal["layer_norm", "rms_norm"] | None = None
-    norm_level: Literal["latent", "channel"] = "latent"
+    encoder_norm: bool = False
     psz: int | None = None
 
 
@@ -306,6 +305,14 @@ class Decoder(nn.Module):
         return h
 
 
+def layer_norm_2d(input: torch.Tensor, normalized_shape: torch.Size, eps: float = 1e-6) -> torch.Tensor:
+    # input.shape = (bsz, c, h, w)
+    _input = input.permute(0, 2, 3, 1)
+    _input = F.layer_norm(_input, normalized_shape, None, None, eps)
+    _input = _input.permute(0, 3, 1, 2)
+    return _input
+
+
 class AutoencoderKL(nn.Module):
     def __init__(self, params: AutoEncoderParams):
         super().__init__()
@@ -333,19 +340,8 @@ class AutoencoderKL(nn.Module):
             z_channels=params.z_channels,
         )
 
+        self.encoder_norm = params.encoder_norm
         self.psz = params.psz
-        # if self.psz is not None:
-        #     logger.warning("psz has been deprecated, this is only used for hack's vae")
-
-        if params.norm_fn is None:
-            self.norm_fn = identity
-        elif params.norm_fn == "layer_norm":
-            self.norm_fn = layer_norm
-        elif params.norm_fn == "rms_norm":
-            self.norm_fn = rms_norm
-        else:
-            raise ValueError(f"Invalid norm_fn: {params.norm_fn}")
-        self.norm_level = params.norm_level
 
         self.apply(self._init_weights)
 
@@ -420,18 +416,17 @@ class AutoencoderKL(nn.Module):
     def encode(self, x: torch.Tensor, return_dict: bool = True):
         moments = self.encoder(x)
 
-        if self.norm_fn is not None:
-            mean, logvar = torch.chunk(moments, 2, dim=1)
-            if self.psz is not None:  # HACK
-                mean = self.patchify(mean)
-            if self.norm_level == "latent":
-                mean = self.norm_fn(mean, mean.size()[1:])
-            elif self.norm_level == "channel":
-                mean = mean.permute(0, 2, 3, 1)  # [bsz, c, h, w] --> [bsz, h, w, c]
-                mean = self.norm_fn(mean, mean.size()[-1:]).permute(0, 3, 1, 2)  # [bsz, h, w, c] --> [bsz, c, h, w]
-            if self.psz is not None:  # HACK
-                mean = self.unpatchify(mean)
-            moments = torch.cat([mean, logvar], dim=1).contiguous()
+        mean, logvar = torch.chunk(moments, 2, dim=1)
+        if self.psz is not None:
+            mean = self.patchify(mean)
+
+        if self.encoder_norm:
+            mean = layer_norm_2d(mean, mean.size()[-1:])
+
+        if self.psz is not None:
+            mean = self.unpatchify(mean)
+
+        moments = torch.cat([mean, logvar], dim=1).contiguous()
 
         posterior = DiagonalGaussianDistribution(moments, deterministic=self.params.deterministic)
 
@@ -448,14 +443,7 @@ class AutoencoderKL(nn.Module):
 
         return DecoderOutput(sample=dec)
 
-    def forward(
-        self,
-        input,
-        sample_posterior=True,
-        noise_strength=0.0,
-        interpolative_noise=False,
-        t_dist: Literal["uniform", "logitnormal"] = "logitnormal",
-    ):
+    def forward(self, input, sample_posterior=True, noise_strength=0.0):
         posterior = self.encode(input).latent_dist
         z = posterior.sample() if sample_posterior else posterior.mode()
         if noise_strength > 0.0:
@@ -463,46 +451,25 @@ class AutoencoderKL(nn.Module):
             z = z + p.sample((z.shape[0],)).reshape(-1, 1, 1, 1).to(z.device) * randn_tensor(
                 z.shape, device=z.device, dtype=z.dtype
             )
-        if interpolative_noise:
-            z = self.patchify(z)
-            bsz, c, h, w = z.shape
-            z = z.permute(0, 2, 3, 1)  # [bsz, h, w, c]
-            z = z.reshape(-1, c)  # [bsz * h * w, c]
-
-            if t_dist == "logitnormal":
-                u = torch.normal(mean=0.0, std=1.0, size=(z.shape[0],))
-                t = (1 / (1 + torch.exp(-u))).to(z)
-            elif t_dist == "uniform":
-                t = torch.randn((z.shape[0],)).to(z)
-            else:
-                raise ValueError(f"Invalid t_dist: {t_dist}")
-
-            noise = torch.randn_like(z)
-            z = expand_t(t, z) * z + (1 - expand_t(t, z)) * noise
-
-            z = z.reshape(bsz, h, w, c).permute(0, 3, 1, 2)
-            z = self.unpatchify(z)
-
         dec = self.decode(z).sample
         return dec, posterior
 
     @classmethod
-    def from_pretrained(cls, pretrained_model_name_or_path: str = "flux-vae", **kwargs):
-        config_path = None
-        ckpt_path = pretrained_model_name_or_path
-        if ckpt_path is not None and os.path.isdir(ckpt_path):
-            config_path = os.path.join(ckpt_path, "config.json")
-            ckpt_path = os.path.join(ckpt_path, "checkpoint.pt")
-        state_dict = torch.load(ckpt_path, map_location="cpu") if ckpt_path is not None else None
-
-        if kwargs is None:
-            kwargs = {}
-
-        if config_path is not None:
-            with open(config_path, "r") as f:
-                config: dict = json.load(f)
-            config.update(kwargs)
-            kwargs = config
+    def from_pretrained(cls, model_path, **kwargs):
+        config_path = os.path.join(model_path, "config.json")
+        ckpt_path = os.path.join(model_path, "checkpoint.pt")
+
+        if not os.path.isdir(model_path) or not os.path.isfile(config_path) or not os.path.isfile(ckpt_path):
+            raise ValueError(
+                f"Invalid model path: {model_path}. The path should contain both config.json and checkpoint.pt files."
+            )
+
+        state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
+
+        with open(config_path, "r") as f:
+            config: dict = json.load(f)
+        config.update(kwargs)
+        kwargs = config
 
         # Filter out kwargs that are not in AutoEncoderParams
         # This ensures we only pass parameters that the model can accept
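Note: a usage sketch for the new layer_norm_2d helper (shapes illustrative, import path assumed). The function permutes the NCHW input to channels-last before calling F.layer_norm, so normalized_shape should be the channel dimension of the input:

    import torch
    from vae.nextstep_ae import layer_norm_2d  # assumed import path

    x = torch.randn(2, 16, 32, 32)      # (bsz, c, h, w), e.g. a latent mean
    y = layer_norm_2d(x, x.shape[1:2])  # normalized_shape = (c,), the last dim after the internal permute
    assert y.shape == x.shape           # each spatial position normalized across its 16 channels

The reworked from_pretrained is also stricter than before: it accepts only a local directory containing both config.json and checkpoint.pt, and loads the checkpoint with weights_only=True.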