Spaces:
Runtime error
Runtime error
use revision=flax
Browse files
app.py
CHANGED
@@ -31,8 +31,8 @@ cnet, cnet_params = FlaxControlNetModel.from_pretrained("./models/catcon-control
|
|
31 |
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
|
32 |
"./models/wd-1-5-b2",
|
33 |
controlnet=cnet,
|
|
|
34 |
dtype=jnp.bfloat16,
|
35 |
-
from_pt=True
|
36 |
)
|
37 |
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
38 |
#pipe.enable_model_cpu_offload()
|
@@ -72,9 +72,10 @@ def infer(prompt, negative_prompt, image):
|
|
72 |
neg_prompt_ids=n_prompt_in,
|
73 |
num_inference_steps=20,
|
74 |
jit=True
|
75 |
-
)
|
76 |
|
77 |
-
|
|
|
78 |
|
79 |
gr.Interface(
|
80 |
infer,
|
|
|
31 |
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
|
32 |
"./models/wd-1-5-b2",
|
33 |
controlnet=cnet,
|
34 |
+
revision="flax"
|
35 |
dtype=jnp.bfloat16,
|
|
|
36 |
)
|
37 |
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
|
38 |
#pipe.enable_model_cpu_offload()
|
|
|
72 |
neg_prompt_ids=n_prompt_in,
|
73 |
num_inference_steps=20,
|
74 |
jit=True
|
75 |
+
).images
|
76 |
|
77 |
+
output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
|
78 |
+
return output_images
|
79 |
|
80 |
gr.Interface(
|
81 |
infer,
|