multimodalart (HF Staff) committed
Commit e1b4118 · verified · Parent(s): 85cc5c2

Update app.py

Files changed (1): app.py (+12 -58)
app.py CHANGED
@@ -51,11 +51,6 @@ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024,
     torch.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
 
-    # Calculate latent dimensions based on image size
-    # Sana uses 32x downsampling factor
-    latent_height = height // 32
-    latent_width = width // 32
-
     with torch.inference_mode():
         # Encode the prompt
         prompt_embeds, prompt_attention_mask = pipe.encode_prompt(
@@ -63,14 +58,14 @@ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024,
             device=device
         )
 
-        # Generate initial random latents with correct dimensions
+        # Generate initial random latents
         init_latents = torch.randn(
-            [1, 32, latent_height, latent_width],
+            [1, 32, 32, 32],
             device=device,
             dtype=dtype
         )
 
-        # Apply HyperNoise modulation with adapter enabled (single forward pass)
+        # Apply HyperNoise modulation with adapter enabled
         pipe.transformer.enable_adapter_layers()
         modulated_latents = pipe.transformer(
             hidden_states=init_latents,
@@ -82,56 +77,15 @@ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024,
 
         # Generate final image with adapter disabled
         pipe.transformer.disable_adapter_layers()
-
-        # For SCM scheduler, we need to handle the timesteps carefully
-        # The pipeline expects intermediate_timesteps only when num_inference_steps=2
-        # For other values, we use the workaround from the original code
-        if num_inference_steps == 2:
-            # Use the default pipeline behavior for 2 steps
-            image = pipe(
-                latents=modulated_latents,
-                prompt_embeds=prompt_embeds,
-                prompt_attention_mask=prompt_attention_mask,
-                num_inference_steps=num_inference_steps,
-            ).images[0]
-        else:
-            # For num_inference_steps != 2, we need to work around the restriction
-            # by directly calling the denoising loop
-            pipe.scheduler.set_timesteps(
-                num_inference_steps,
-                device=device,
-                timesteps=torch.linspace(1.57080, 0, num_inference_steps + 1, device=device)
-            )
-
-            # Run the denoising loop manually
-            latents = modulated_latents
-            for i, t in enumerate(pipe.scheduler.timesteps[:-1]):
-                # Expand timestep to match batch dimension
-                timestep = t.expand(latents.shape[0])
-
-                # Predict noise
-                noise_pred = pipe.transformer(
-                    hidden_states=latents,
-                    encoder_hidden_states=prompt_embeds,
-                    encoder_attention_mask=prompt_attention_mask,
-                    timestep=timestep,
-                    guidance=torch.tensor([0.0], device=device, dtype=dtype),  # No guidance for denoising
-                    return_dict=False,
-                )[0]
-
-                # Compute previous noisy sample
-                latents = pipe.scheduler.step(
-                    noise_pred,
-                    t,
-                    latents,
-                    return_dict=False
-                )[0]
-
-            # Decode latents to image
-            latents = pipe._unpack_latents(latents, height, width, pipe.vae_scale_factor)
-            latents = (latents / pipe.vae.scaling_factor) + pipe.vae.shift_factor
-            image = pipe.vae.decode(latents, return_dict=False)[0]
-            image = pipe.image_processor.postprocess(image, output_type="pil")[0]
+        image = pipe(
+            latents=modulated_latents,
+            prompt_embeds=prompt_embeds,
+            prompt_attention_mask=prompt_attention_mask,
+            intermediate_steps=None,
+            num_inference_steps=num_inference_steps,
+            height=height,
+            width=width,
+        ).images[0]
 
     return image, seed
 
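For reference, below is a sketch of the simplified infer path as it reads after this commit, reassembled from the hunks above. It is not the verbatim file: the module-level pipe, device, and dtype come from parts of app.py outside this diff; the tail of the infer signature is elided by the hunk header; and the remaining keyword arguments of the HyperNoise forward pass are cut off at a hunk boundary, so the lines marked "assumed" are guesses modeled on the transformer call in the removed denoising loop.

import torch

# Sketch only. Assumed to be defined elsewhere in app.py (outside this diff):
#   pipe   - the Sana pipeline, with the HyperNoise adapter loaded on pipe.transformer
#   device - e.g. "cuda"
#   dtype  - e.g. torch.bfloat16

def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024,
          num_inference_steps=2):  # trailing parameters elided in the diff; step count assumed
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    with torch.inference_mode():
        # Encode the prompt
        prompt_embeds, prompt_attention_mask = pipe.encode_prompt(
            prompt,  # assumed: the hunk begins one line below this
            device=device
        )

        # Generate initial random latents: after this commit the latent grid is
        # fixed at 32x32 rather than derived from height/width
        init_latents = torch.randn([1, 32, 32, 32], device=device, dtype=dtype)

        # Apply HyperNoise modulation with adapter enabled
        pipe.transformer.enable_adapter_layers()
        modulated_latents = pipe.transformer(
            hidden_states=init_latents,
            encoder_hidden_states=prompt_embeds,           # assumed, cut off by the hunk
            encoder_attention_mask=prompt_attention_mask,  # assumed, cut off by the hunk
            return_dict=False,                             # assumed, cut off by the hunk
        )[0]

        # Generate final image with adapter disabled; the pipeline call now owns
        # the scheduling, latent unpacking, and VAE decoding that the removed
        # branch performed by hand
        pipe.transformer.disable_adapter_layers()
        image = pipe(
            latents=modulated_latents,
            prompt_embeds=prompt_embeds,
            prompt_attention_mask=prompt_attention_mask,
            intermediate_steps=None,
            num_inference_steps=num_inference_steps,
            height=height,
            width=width,
        ).images[0]

    return image, seed

Net effect of the commit: the num_inference_steps == 2 special case and the manual SCM fallback loop collapse into a single pipe(...) call, with height and width forwarded to the pipeline instead of being folded into the latent shape up front.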