gaparmar committed · Commit a5f38fd · 1 parent: 13ed5cd

gamma

Files changed:
- app.py +1 -1
- src/model.py +46 -1
- src/pix2pix_turbo.py +3 -46
app.py
CHANGED

@@ -238,7 +238,7 @@ with gr.Blocks(css="style.css") as demo:
             prompt_temp = gr.Textbox(label="Prompt Style Template", value=styles[DEFAULT_STYLE_NAME], scale=2, max_lines=1)
 
         with gr.Row():
-            val_r = gr.Slider(label="
+            val_r = gr.Slider(label="Sketch guidance gamma: ", show_label=True, minimum=0, maximum=1, value=0.4, step=0.01, scale=3)
             seed = gr.Textbox(label="Seed", value=42, scale=1, min_width=50)
             randomize_seed = gr.Button("Random", scale=1, min_width=50)
 
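Note: the only app.py change is the new val_r slider, which exposes the sketch guidance gamma in the UI. Below is a minimal, self-contained sketch of the control in isolation; the run handler and the .change() wiring are hypothetical, since the demo's real event handlers sit outside this hunk (they forward val_r to the model as r, see src/pix2pix_turbo.py below).

import gradio as gr

# Hypothetical handler: app.py's real handlers live outside this hunk
# and forward val_r to the model as r (the decoder's gamma).
def run(val_r):
    return f"sketch guidance gamma = {val_r}"

with gr.Blocks() as demo:
    val_r = gr.Slider(label="Sketch guidance gamma: ", show_label=True,
                      minimum=0, maximum=1, value=0.4, step=0.01, scale=3)
    out = gr.Textbox(label="Current value")
    val_r.change(run, inputs=val_r, outputs=out)

# demo.launch()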
src/model.py
CHANGED

@@ -10,4 +10,49 @@ def make_1step_sched():
     noise_scheduler_1step = DDPMScheduler.from_pretrained("stabilityai/sd-turbo", subfolder="scheduler")
     noise_scheduler_1step.set_timesteps(1, device="cuda")
     noise_scheduler_1step.alphas_cumprod = noise_scheduler_1step.alphas_cumprod.cuda()
-    return noise_scheduler_1step
+    return noise_scheduler_1step
+
+
+"""The forward method of the `Encoder` class."""
+def my_vae_encoder_fwd(self, sample):
+    sample = self.conv_in(sample)
+    l_blocks = []
+    # down
+    for down_block in self.down_blocks:
+        l_blocks.append(sample)
+        sample = down_block(sample)
+    # middle
+    sample = self.mid_block(sample)
+    sample = self.conv_norm_out(sample)
+    sample = self.conv_act(sample)
+    sample = self.conv_out(sample)
+    self.current_down_blocks = l_blocks
+    return sample
+
+
+"""The forward method of the `Decoder` class."""
+def my_vae_decoder_fwd(self, sample, latent_embeds=None):
+    sample = self.conv_in(sample)
+    upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
+    # middle
+    sample = self.mid_block(sample, latent_embeds)
+    sample = sample.to(upscale_dtype)
+    if not self.ignore_skip:
+        skip_convs = [self.skip_conv_1, self.skip_conv_2, self.skip_conv_3, self.skip_conv_4]
+        # up
+        for idx, up_block in enumerate(self.up_blocks):
+            skip_in = skip_convs[idx](self.incoming_skip_acts[::-1][idx] * self.gamma)
+            # add skip
+            sample = sample + skip_in
+            sample = up_block(sample, latent_embeds)
+    else:
+        for idx, up_block in enumerate(self.up_blocks):
+            sample = up_block(sample, latent_embeds)
+    # post-process
+    if latent_embeds is None:
+        sample = self.conv_norm_out(sample)
+    else:
+        sample = self.conv_norm_out(sample, latent_embeds)
+    sample = self.conv_act(sample)
+    sample = self.conv_out(sample)
+    return sample
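Note: both are free functions written to be bound onto a diffusers VAE's Encoder and Decoder instances; hence the explicit self parameter and the gamma, ignore_skip, skip_conv_*, and incoming_skip_acts attributes the decoder relies on. A minimal sketch of that binding follows, assuming diffusers' AutoencoderKL; the actual patching happens in src/pix2pix_turbo.py outside this diff, and the 1x1 skip-conv channel widths below are illustrative assumptions, not taken from this commit.

import torch
from diffusers import AutoencoderKL
from model import my_vae_encoder_fwd, my_vae_decoder_fwd

# Bind the free functions as methods on the encoder/decoder instances.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-turbo", subfolder="vae")
vae.encoder.forward = my_vae_encoder_fwd.__get__(vae.encoder, vae.encoder.__class__)
vae.decoder.forward = my_vae_decoder_fwd.__get__(vae.decoder, vae.decoder.__class__)

# Attributes the patched decoder forward expects. The 1x1 skip convs map
# stored encoder activations to the decoder's channel widths; the channel
# sizes here are illustrative guesses, not taken from this commit.
vae.decoder.ignore_skip = False
vae.decoder.gamma = 1
vae.decoder.skip_conv_1 = torch.nn.Conv2d(512, 512, kernel_size=1, bias=False)
vae.decoder.skip_conv_2 = torch.nn.Conv2d(256, 512, kernel_size=1, bias=False)
vae.decoder.skip_conv_3 = torch.nn.Conv2d(128, 512, kernel_size=1, bias=False)
vae.decoder.skip_conv_4 = torch.nn.Conv2d(128, 256, kernel_size=1, bias=False)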
src/pix2pix_turbo.py
CHANGED

@@ -11,52 +11,7 @@ from diffusers.utils.peft_utils import set_weights_and_activate_adapters
 from peft import LoraConfig
 p = "src/"
 sys.path.append(p)
-from model import make_1step_sched
-
-
-"""The forward method of the `Encoder` class."""
-def my_vae_encoder_fwd(self, sample):
-    sample = self.conv_in(sample)
-    l_blocks = []
-    # down
-    for down_block in self.down_blocks:
-        l_blocks.append(sample)
-        sample = down_block(sample)
-    # middle
-    sample = self.mid_block(sample)
-    sample = self.conv_norm_out(sample)
-    sample = self.conv_act(sample)
-    sample = self.conv_out(sample)
-    self.current_down_blocks = l_blocks
-    return sample
-
-
-"""The forward method of the `Decoder` class."""
-def my_vae_decoder_fwd(self, sample, latent_embeds=None):
-    sample = self.conv_in(sample)
-    upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
-    # middle
-    sample = self.mid_block(sample, latent_embeds)
-    sample = sample.to(upscale_dtype)
-    if not self.ignore_skip:
-        skip_convs = [self.skip_conv_1, self.skip_conv_2, self.skip_conv_3, self.skip_conv_4]
-        # up
-        for idx, up_block in enumerate(self.up_blocks):
-            skip_in = skip_convs[idx](self.incoming_skip_acts[::-1][idx])
-            # add skip
-            sample = sample + skip_in
-            sample = up_block(sample, latent_embeds)
-    else:
-        for idx, up_block in enumerate(self.up_blocks):
-            sample = up_block(sample, latent_embeds)
-    # post-process
-    if latent_embeds is None:
-        sample = self.conv_norm_out(sample)
-    else:
-        sample = self.conv_norm_out(sample, latent_embeds)
-    sample = self.conv_act(sample)
-    sample = self.conv_out(sample)
-    return sample
+from model import make_1step_sched, my_vae_encoder_fwd, my_vae_decoder_fwd
 
 
 class TwinConv(torch.nn.Module):

@@ -151,6 +106,7 @@ class Pix2Pix_Turbo(torch.nn.Module):
         unet.eval()
         vae.eval()
         self.unet, self.vae = unet, vae
+        self.vae.decoder.gamma = 1
         self.timesteps = torch.tensor([999], device="cuda").long()
 
 
@@ -177,5 +133,6 @@ class Pix2Pix_Turbo(torch.nn.Module):
         self.unet.conv_in.r = None
         x_denoised = self.sched.step(unet_output, self.timesteps, unet_input, return_dict=True).prev_sample
         self.vae.decoder.incoming_skip_acts = self.vae.encoder.current_down_blocks
+        self.vae.decoder.gamma = r
         output_image = (self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample).clamp(-1, 1)
         return output_image
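Net effect of the commit: the two VAE forward functions move into src/model.py, and the decoder's skip path gains a runtime gamma, initialized to 1 in __init__ and overwritten with the slider value r on each forward call. Since the skip activations are multiplied by gamma before the 1x1 convs, r = 0 silences the encoder-to-decoder sketch skips and r = 1 applies them at full strength. A tiny self-contained illustration (all names local to this sketch):

import torch

# Because the 1x1 skip conv is linear (no bias), scaling its input by
# gamma scales its output by gamma: gamma = 0 removes the skip signal,
# gamma = 1 passes it through at full strength.
skip_conv = torch.nn.Conv2d(8, 8, kernel_size=1, bias=False)
encoder_act = torch.randn(1, 8, 16, 16)

for gamma in (0.0, 0.4, 1.0):
    skip_in = skip_conv(encoder_act * gamma)
    print(f"gamma={gamma}: mean |skip_in| = {skip_in.abs().mean().item():.4f}")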
|