Cognomen commited on
Commit
c848fb3
Β·
1 Parent(s): eeddd9f

use revision=flax

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -31,8 +31,8 @@ cnet, cnet_params = FlaxControlNetModel.from_pretrained("./models/catcon-control
31
  pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
32
  "./models/wd-1-5-b2",
33
  controlnet=cnet,
 
34
  dtype=jnp.bfloat16,
35
- from_pt=True
36
  )
37
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
38
  #pipe.enable_model_cpu_offload()
@@ -72,9 +72,10 @@ def infer(prompt, negative_prompt, image):
72
  neg_prompt_ids=n_prompt_in,
73
  num_inference_steps=20,
74
  jit=True
75
- )
76
 
77
- return output.images
 
78
 
79
  gr.Interface(
80
  infer,
 
31
  pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
32
  "./models/wd-1-5-b2",
33
  controlnet=cnet,
34
+ revision="flax"
35
  dtype=jnp.bfloat16,
 
36
  )
37
  pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
38
  #pipe.enable_model_cpu_offload()
 
72
  neg_prompt_ids=n_prompt_in,
73
  num_inference_steps=20,
74
  jit=True
75
+ ).images
76
 
77
+ output_images = pipe.numpy_to_pil(np.asarray(output.reshape((num_samples,) + output.shape[-3:])))
78
+ return output_images
79
 
80
  gr.Interface(
81
  infer,