multimodalart HF Staff committed on
Commit
85cc5c2
·
verified ·
1 Parent(s): 606b9af

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -15
app.py CHANGED
@@ -83,27 +83,55 @@ def infer(prompt, seed=42, randomize_seed=False, width=1024, height=1024,
83
  # Generate final image with adapter disabled
84
  pipe.transformer.disable_adapter_layers()
85
 
86
- # Manually set timesteps to avoid intermediate_timesteps issue
87
- # SCM scheduler only supports intermediate_timesteps when num_inference_steps=2
 
88
  if num_inference_steps == 2:
89
- pipe.scheduler.set_timesteps(num_inference_steps, device=device)
 
 
 
 
 
 
90
  else:
91
- # For num_inference_steps != 2, we need to avoid intermediate_timesteps
92
- max_timesteps = 1.57080 # Default from SCM paper
93
  pipe.scheduler.set_timesteps(
94
  num_inference_steps,
95
  device=device,
96
- max_timesteps=max_timesteps,
97
- intermediate_timesteps=None
98
  )
99
-
100
- # Now generate the image with pre-set timesteps
101
- image = pipe(
102
- latents=modulated_latents,
103
- prompt_embeds=prompt_embeds,
104
- prompt_attention_mask=prompt_attention_mask,
105
- num_inference_steps=num_inference_steps,
106
- ).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
  return image, seed
109
 
 
83
  # Generate final image with adapter disabled
84
  pipe.transformer.disable_adapter_layers()
85
 
86
+ # For SCM scheduler, we need to handle the timesteps carefully
87
+ # The pipeline expects intermediate_timesteps only when num_inference_steps=2
88
+ # For other values, we use the workaround from the original code
89
  if num_inference_steps == 2:
90
+ # Use the default pipeline behavior for 2 steps
91
+ image = pipe(
92
+ latents=modulated_latents,
93
+ prompt_embeds=prompt_embeds,
94
+ prompt_attention_mask=prompt_attention_mask,
95
+ num_inference_steps=num_inference_steps,
96
+ ).images[0]
97
  else:
98
+ # For num_inference_steps != 2, we need to work around the restriction
99
+ # by directly calling the denoising loop
100
  pipe.scheduler.set_timesteps(
101
  num_inference_steps,
102
  device=device,
103
+ timesteps=torch.linspace(1.57080, 0, num_inference_steps + 1, device=device)
 
104
  )
105
+
106
+ # Run the denoising loop manually
107
+ latents = modulated_latents
108
+ for i, t in enumerate(pipe.scheduler.timesteps[:-1]):
109
+ # Expand timestep to match batch dimension
110
+ timestep = t.expand(latents.shape[0])
111
+
112
+ # Predict noise
113
+ noise_pred = pipe.transformer(
114
+ hidden_states=latents,
115
+ encoder_hidden_states=prompt_embeds,
116
+ encoder_attention_mask=prompt_attention_mask,
117
+ timestep=timestep,
118
+ guidance=torch.tensor([0.0], device=device, dtype=dtype), # No guidance for denoising
119
+ return_dict=False,
120
+ )[0]
121
+
122
+ # Compute previous noisy sample
123
+ latents = pipe.scheduler.step(
124
+ noise_pred,
125
+ t,
126
+ latents,
127
+ return_dict=False
128
+ )[0]
129
+
130
+ # Decode latents to image
131
+ latents = pipe._unpack_latents(latents, height, width, pipe.vae_scale_factor)
132
+ latents = (latents / pipe.vae.scaling_factor) + pipe.vae.shift_factor
133
+ image = pipe.vae.decode(latents, return_dict=False)[0]
134
+ image = pipe.image_processor.postprocess(image, output_type="pil")[0]
135
 
136
  return image, seed
137