Cognomen committed
Commit 1da736d · Parent: f3148bf

do more of what canny_coyo1m does

Files changed (1): app.py (+12 -8)
app.py CHANGED

@@ -1,5 +1,6 @@
 import gradio as gr
-from diffusers import StableDiffusionControlNetPipeline, ControlNetModel
+import jax.numpy as jnp
+from diffusers import FlaxStableDiffusionControlNetPipeline, FlaxControlNetModel
 from diffusers import UniPCMultistepScheduler
 import torch
 torch.backends.cuda.matmul.allow_tf32 = True
@@ -24,12 +25,12 @@ conditioning_image_transforms = T.Compose(
     ]
 )
 
-cnet = ControlNetModel.from_pretrained("./models/catcon-controlnet-wd", torch_dtype=torch.float16, from_flax=True).to("cuda")
-pipe = StableDiffusionControlNetPipeline.from_pretrained(
+cnet = FlaxControlNetModel.from_pretrained("./models/catcon-controlnet-wd", dtype=jnp.bfloat16, from_flax=True)
+pipe = FlaxStableDiffusionControlNetPipeline.from_pretrained(
     "./models/wd-1-5-b2",
     controlnet=cnet,
-    torch_dtype=torch.float16,
-).to("cuda")
+    dtype=jnp.bfloat16,
+)
 pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
 #pipe.enable_model_cpu_offload()
 #pipe.enable_xformers_memory_efficient_attention()
@@ -41,15 +42,18 @@ def infer(prompt, negative_prompt, image):
     # implement your inference function here
     inp = Image.fromarray(image)
 
-    cond_input = conditioning_image_transforms(inp).to("cpu", dtype=torch.float32)
+    cond_input = conditioning_image_transforms(inp)
     cond_input = T.ToPILImage()(cond_input)
+
+    cond_in = pipe.prepare_image_inputs([cond_input] * 4)
 
     output = pipe(
         prompt,
-        cond_input,
+        cond_in,
         generator=generator,
         num_images_per_prompt=4,
-        num_inference_steps=20
+        num_inference_steps=20,
+        jit=True
     )
 
     return output.images
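
For reference, a minimal sketch of the Flax calling convention from the diffusers docs, under the assumption that both local checkpoints load cleanly as Flax weights; the torchvision conditioning transform from app.py is elided. The torch-side pieces kept in this commit (from_flax=True, generator=, and the UniPCMultistepScheduler assignment) do not carry over directly: the Flax from_pretrained calls return a (module, params) tuple, scheduler state lives in those params, and a jitted call takes tokenized prompt_ids and a prng_seed rather than a generator.

import jax
import jax.numpy as jnp
import numpy as np
from PIL import Image
from flax.jax_utils import replicate
from flax.training.common_utils import shard
from diffusers import FlaxControlNetModel, FlaxStableDiffusionControlNetPipeline

# Both Flax from_pretrained calls return (module, params); the controlnet
# params are then merged into the pipeline's param tree.
cnet, cnet_params = FlaxControlNetModel.from_pretrained(
    "./models/catcon-controlnet-wd", dtype=jnp.bfloat16
)
pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
    "./models/wd-1-5-b2", controlnet=cnet, dtype=jnp.bfloat16
)
params["controlnet"] = cnet_params

def infer(prompt, negative_prompt, image):
    n = jax.device_count()  # jit=True pmaps over devices, one batch entry each
    prompt_ids = pipe.prepare_text_inputs([prompt] * n)
    neg_ids = pipe.prepare_text_inputs([negative_prompt] * n)
    cond_in = pipe.prepare_image_inputs([Image.fromarray(image)] * n)
    rng = jax.random.split(jax.random.PRNGKey(0), n)  # one PRNG key per device

    output = pipe(
        prompt_ids=shard(prompt_ids),
        image=shard(cond_in),
        params=replicate(params),
        prng_seed=rng,
        num_inference_steps=20,
        neg_prompt_ids=shard(neg_ids),
        jit=True,
    )
    # jitted output.images has shape (devices, batch, H, W, 3); flatten the
    # leading axes before converting back to PIL for Gradio.
    images = output.images.reshape((-1,) + output.images.shape[-3:])
    return pipe.numpy_to_pil(np.asarray(images))

On a single accelerator jax.device_count() is 1, so this produces one image per call; matching the commit's 4-image batch would mean batching per device rather than relying on num_images_per_prompt, which the Flax pipeline does not expose.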