Spaces: Running on Zero

fix

- app.py +31 -30
- vae/config.json +1 -2
- vae/nextstep_ae.py +39 -72
app.py
CHANGED
@@ -3,9 +3,8 @@ import numpy as np
 import random
 import spaces
 from PIL import Image
-
-# import spaces #[uncomment to use ZeroGPU]
 import torch
+from torch.amp import autocast
 
 from transformers import AutoTokenizer, AutoModel
 from models.gen_pipeline import NextStepPipeline
@@ -15,8 +14,15 @@ HF_HUB = "stepfun-ai/NextStep-1-Large"
 device = "cuda" if torch.cuda.is_available() else "cpu"
 
 tokenizer = AutoTokenizer.from_pretrained(HF_HUB, local_files_only=False, trust_remote_code=True)
-
-
+
+model = AutoModel.from_pretrained(
+    HF_HUB,
+    local_files_only=False,
+    trust_remote_code=True,
+    torch_dtype=torch.bfloat16,
+).to(device)
+
+pipeline = NextStepPipeline(tokenizer=tokenizer, model=model).to(device=device, dtype=torch.bfloat16)
 
 MAX_SEED = np.iinfo(np.int16).max
 MAX_IMAGE_SIZE = 512
@@ -30,8 +36,6 @@ def infer(
     seed=0,
     width=512,
     height=512,
-    #text_cfg=7.5,
-    #img_cfg=1.0,
     num_inference_steps=28,
     positive_prompt=DEFAULT_POSITIVE_PROMPT,
     negative_prompt=DEFAULT_NEGATIVE_PROMPT,
@@ -40,21 +44,23 @@ def infer(
     if prompt in [None, ""]:
        gr.Warning("⚠️ Please enter a prompt!")
        return None
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+    with autocast(device_type=("cuda" if device == "cuda" else "cpu"), dtype=torch.bfloat16):
+        image = pipeline.generate_image(
+            prompt,
+            hw=(height, width),
+            num_images_per_caption=1,
+            positive_prompt=positive_prompt,
+            negative_prompt=negative_prompt,
+            cfg=7.5,
+            cfg_img=1.0,
+            cfg_schedule="constant",
+            use_norm=False,
+            num_sampling_steps=num_inference_steps,
+            timesteps_shift=1.0,
+            seed=seed,
+            progress=True,
+        )
 
     return image[0]
 
@@ -114,7 +120,7 @@ with gr.Blocks(css=css) as demo:
                 step=1,
                 value=28,
             )
-
+
            with gr.Row():
                width = gr.Slider(
                    label="Width",
@@ -132,9 +138,7 @@ with gr.Blocks(css=css) as demo:
                )
 
        with gr.Row():
-            result_1 = gr.Image(label="Result 1", show_label=False, container=True, height=
-
-            #gr.Examples(examples=examples, inputs=[prompt, ref])
+            result_1 = gr.Image(label="Result 1", show_label=False, container=True, height=MAX_IMAGE_SIZE, visible=False)
 
        def show_result():
            return gr.update(visible=True)
@@ -147,15 +151,13 @@ with gr.Blocks(css=css) as demo:
            seed,
            width,
            height,
-            #text_cfg,
-            #img_cfg,
            num_inference_steps,
            positive_prompt,
            negative_prompt,
        ],
        outputs=[result_1],
    )
-
+
    cancel_button.click(
        fn=None,
        inputs=None,
@@ -169,6 +171,5 @@ with gr.Blocks(css=css) as demo:
        outputs=[result_1],
    )
 
-
 if __name__ == "__main__":
-    demo.launch()
+    demo.launch()
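The commit moves model construction to startup (loaded once in bfloat16) and wraps generation in torch.amp.autocast. Below is a minimal sketch of the resulting inference path, condensed from the hunks above. The @spaces.GPU decorator is an assumption: it is the usual pattern for a Space running on ZeroGPU but does not appear in the visible hunks. NextStepPipeline.generate_image and its keyword arguments are taken from the diff; the positive/negative prompt defaults are omitted for brevity.

import spaces
import torch
from torch.amp import autocast
from transformers import AutoTokenizer, AutoModel
from models.gen_pipeline import NextStepPipeline

HF_HUB = "stepfun-ai/NextStep-1-Large"
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load once at import time, in bfloat16, as the diff does.
tokenizer = AutoTokenizer.from_pretrained(HF_HUB, trust_remote_code=True)
model = AutoModel.from_pretrained(
    HF_HUB, trust_remote_code=True, torch_dtype=torch.bfloat16
).to(device)
pipeline = NextStepPipeline(tokenizer=tokenizer, model=model).to(device=device, dtype=torch.bfloat16)

@spaces.GPU  # assumed: standard ZeroGPU decorator, not shown in the visible hunks
def generate(prompt: str, seed: int = 0, steps: int = 28):
    # Autocast keeps activations in bfloat16 to match the weights.
    with autocast(device_type=device, dtype=torch.bfloat16):
        images = pipeline.generate_image(
            prompt,
            hw=(512, 512),
            num_images_per_caption=1,
            cfg=7.5,
            cfg_img=1.0,
            cfg_schedule="constant",
            use_norm=False,
            num_sampling_steps=steps,
            timesteps_shift=1.0,
            seed=seed,
            progress=True,
        )
    return images[0]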
vae/config.json
CHANGED
@@ -9,7 +9,6 @@
     "shift_factor": 0,
     "scaling_factor": 1,
     "deterministic": true,
-    "
-    "norm_level": "channel",
+    "encoder_norm": true,
     "psz": 1
 }
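The change swaps the old norm_level key (plus a second key left truncated in the rendered diff) for a single encoder_norm boolean, matching the AutoEncoderParams change in vae/nextstep_ae.py below. A sketch of how the flag reaches the dataclass, assuming the kwargs filtering the code's own comment describes ("Filter out kwargs that are not in AutoEncoderParams"); the dataclass here is a visible-fields subset, not the full definition:

import json
from dataclasses import dataclass, fields

@dataclass
class AutoEncoderParams:  # subset of the real dataclass; fields shown in this commit
    scaling_factor: float = 0.3611
    shift_factor: float = 0.1159
    deterministic: bool = False
    encoder_norm: bool = False  # new flag, replacing norm_level
    psz: int | None = None

with open("vae/config.json") as f:
    cfg = json.load(f)

# Keep only keys the dataclass accepts, mirroring from_pretrained's filtering.
allowed = {f.name for f in fields(AutoEncoderParams)}
params = AutoEncoderParams(**{k: v for k, v in cfg.items() if k in allowed})
assert params.encoder_norm is True  # set by this commit's config change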
vae/nextstep_ae.py
CHANGED
@@ -2,7 +2,6 @@ import os
 import json
 import inspect
 from dataclasses import dataclass, field, asdict
-from typing import Literal
 from loguru import logger
 from omegaconf import OmegaConf
 from tabulate import tabulate
@@ -10,6 +9,7 @@ from einops import rearrange
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 from torch import Tensor
 from torch.utils.checkpoint import checkpoint
 
@@ -17,7 +17,7 @@ from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
 from diffusers.models.modeling_outputs import AutoencoderKLOutput
 
 from utils.misc import LargeInt
-from utils.model_utils import
+from utils.model_utils import randn_tensor
 from utils.compile_utils import smart_compile
 
 
@@ -33,8 +33,7 @@ class AutoEncoderParams:
     scaling_factor: float = 0.3611
     shift_factor: float = 0.1159
     deterministic: bool = False
-
-    norm_level: Literal["latent", "channel"] = "latent"
+    encoder_norm: bool = False
     psz: int | None = None
 
 
@@ -306,6 +305,14 @@ class Decoder(nn.Module):
         return h
 
 
+def layer_norm_2d(input: torch.Tensor, normalized_shape: torch.Size, eps: float = 1e-6) -> torch.Tensor:
+    # input.shape = (bsz, c, h, w)
+    _input = input.permute(0, 2, 3, 1)
+    _input = F.layer_norm(_input, normalized_shape, None, None, eps)
+    _input = _input.permute(0, 3, 1, 2)
+    return _input
+
+
 class AutoencoderKL(nn.Module):
     def __init__(self, params: AutoEncoderParams):
         super().__init__()
@@ -333,19 +340,8 @@ class AutoencoderKL(nn.Module):
            z_channels=params.z_channels,
        )
 
+        self.encoder_norm = params.encoder_norm
        self.psz = params.psz
-        # if self.psz is not None:
-        #     logger.warning("psz has been deprecated, this is only used for hack's vae")
-
-        if params.norm_fn is None:
-            self.norm_fn = identity
-        elif params.norm_fn == "layer_norm":
-            self.norm_fn = layer_norm
-        elif params.norm_fn == "rms_norm":
-            self.norm_fn = rms_norm
-        else:
-            raise ValueError(f"Invalid norm_fn: {params.norm_fn}")
-        self.norm_level = params.norm_level
 
        self.apply(self._init_weights)
 
@@ -420,18 +416,17 @@ class AutoencoderKL(nn.Module):
     def encode(self, x: torch.Tensor, return_dict: bool = True):
        moments = self.encoder(x)
 
-
-
-
-
-
-
-
-
-
-
-
-        moments = torch.cat([mean, logvar], dim=1).contiguous()
+        mean, logvar = torch.chunk(moments, 2, dim=1)
+        if self.psz is not None:
+            mean = self.patchify(mean)
+
+        if self.encoder_norm:
+            mean = layer_norm_2d(mean, mean.size()[-1:])
+
+        if self.psz is not None:
+            mean = self.unpatchify(mean)
+
+        moments = torch.cat([mean, logvar], dim=1).contiguous()
 
        posterior = DiagonalGaussianDistribution(moments, deterministic=self.params.deterministic)
 
@@ -448,14 +443,7 @@ class AutoencoderKL(nn.Module):
 
        return DecoderOutput(sample=dec)
 
-    def forward(
-        self,
-        input,
-        sample_posterior=True,
-        noise_strength=0.0,
-        interpolative_noise=False,
-        t_dist: Literal["uniform", "logitnormal"] = "logitnormal",
-    ):
+    def forward(self, input, sample_posterior=True, noise_strength=0.0):
        posterior = self.encode(input).latent_dist
        z = posterior.sample() if sample_posterior else posterior.mode()
        if noise_strength > 0.0:
@@ -463,46 +451,25 @@ class AutoencoderKL(nn.Module):
            z = z + p.sample((z.shape[0],)).reshape(-1, 1, 1, 1).to(z.device) * randn_tensor(
                z.shape, device=z.device, dtype=z.dtype
            )
-        if interpolative_noise:
-            z = self.patchify(z)
-            bsz, c, h, w = z.shape
-            z = z.permute(0, 2, 3, 1)  # [bsz, h, w, c]
-            z = z.reshape(-1, c)  # [bsz * h * w, c]
-
-            if t_dist == "logitnormal":
-                u = torch.normal(mean=0.0, std=1.0, size=(z.shape[0],))
-                t = (1 / (1 + torch.exp(-u))).to(z)
-            elif t_dist == "uniform":
-                t = torch.randn((z.shape[0],)).to(z)
-            else:
-                raise ValueError(f"Invalid t_dist: {t_dist}")
-
-            noise = torch.randn_like(z)
-            z = expand_t(t, z) * z + (1 - expand_t(t, z)) * noise
-
-            z = z.reshape(bsz, h, w, c).permute(0, 3, 1, 2)
-            z = self.unpatchify(z)
-
        dec = self.decode(z).sample
        return dec, posterior
 
    @classmethod
-    def from_pretrained(cls,
-        config_path =
-        ckpt_path =
-
-
-
-
-
-
-
-
-
-
-        kwargs = config
+    def from_pretrained(cls, model_path, **kwargs):
+        config_path = os.path.join(model_path, "config.json")
+        ckpt_path = os.path.join(model_path, "checkpoint.pt")
+
+        if not os.path.isdir(model_path) or not os.path.isfile(config_path) or not os.path.isfile(ckpt_path):
+            raise ValueError(
+                f"Invalid model path: {model_path}. The path should contain both config.json and checkpoint.pt files."
+            )
+
+        state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
+
+        with open(config_path, "r") as f:
+            config: dict = json.load(f)
+        config.update(kwargs)
+        kwargs = config
 
        # Filter out kwargs that are not in AutoEncoderParams
        # This ensures we only pass parameters that the model can accept
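The new layer_norm_2d helper normalizes an NCHW tensor by permuting to NHWC, applying F.layer_norm with no affine weight/bias, and permuting back; encode uses it on the latent mean when encoder_norm is set, passing mean.size()[-1:] as normalized_shape. A self-contained sanity check, with the function body copied from the diff and the channel count chosen as normalized_shape for this example:

import torch
import torch.nn.functional as F

def layer_norm_2d(input, normalized_shape, eps=1e-6):
    # Copied from the diff: NCHW -> NHWC, LayerNorm over the trailing
    # dim(s) without learnable weight/bias, then back to NCHW.
    _input = input.permute(0, 2, 3, 1)
    _input = F.layer_norm(_input, normalized_shape, None, None, eps)
    _input = _input.permute(0, 3, 1, 2)
    return _input

x = torch.randn(2, 8, 4, 4)          # (bsz, c, h, w)
y = layer_norm_2d(x, x.size()[1:2])  # normalize over the 8 channels

# Every spatial position now has ~zero mean and ~unit variance across channels.
assert y.mean(dim=1).abs().max() < 1e-5
assert (y.std(dim=1, unbiased=False) - 1).abs().max() < 1e-3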
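The rewritten from_pretrained takes a single directory and derives both file paths from it; keyword arguments are merged over config.json via config.update(kwargs). A hedged usage sketch (the directory name and override are illustrative, and applying state_dict to the module happens below the visible hunk):

# Expected local layout:
#   vae/
#     config.json
#     checkpoint.pt
vae = AutoencoderKL.from_pretrained("vae")
# Explicit kwargs override values read from config.json:
vae_no_norm = AutoencoderKL.from_pretrained("vae", encoder_norm=False)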