Cognomen committed on
Commit
eeddd9f
·
1 Parent(s): 23ace94

do what coyo_1m space does

Browse files
Files changed (1) hide show
  1. app.py +25 -8
app.py CHANGED
@@ -6,6 +6,8 @@ import torch
6
  torch.backends.cuda.matmul.allow_tf32 = True
7
  import torchvision
8
  import torchvision.transforms as T
 
 
9
  #from torchvision.transforms import v2 as T2
10
  import cv2
11
  import PIL
@@ -25,8 +27,8 @@ conditioning_image_transforms = T.Compose(
25
  ]
26
  )
27
 
28
- cnet = FlaxControlNetModel.from_pretrained("./models/catcon-controlnet-wd", dtype=jnp.bfloat16, from_flax=True)
29
- pipe = FlaxStableDiffusionControlNetPipeline.from_pretrained(
30
  "./models/wd-1-5-b2",
31
  controlnet=cnet,
32
  dtype=jnp.bfloat16,
@@ -36,23 +38,38 @@ pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
36
  #pipe.enable_model_cpu_offload()
37
  #pipe.enable_xformers_memory_efficient_attention()
38
 
39
- generator = torch.manual_seed(0)
 
40
 
41
  # inference function takes prompt, negative prompt and image
42
  def infer(prompt, negative_prompt, image):
43
  # implement your inference function here
 
 
 
44
  inp = Image.fromarray(image)
45
 
46
  cond_input = conditioning_image_transforms(inp)
47
  cond_input = T.ToPILImage()(cond_input)
48
 
49
- cond_in = pipe.prepare_image_inputs([cond_input] * 4)
 
 
 
 
 
 
 
 
 
 
 
50
 
51
  output = pipe(
52
- prompt,
53
- cond_in,
54
- generator=generator,
55
- num_images_per_prompt=4,
56
  num_inference_steps=20,
57
  jit=True
58
  )
 
6
  torch.backends.cuda.matmul.allow_tf32 = True
7
  import torchvision
8
  import torchvision.transforms as T
9
+ from flax.jax_utils import replicate
10
+ from flax.training.common_utils import shard
11
  #from torchvision.transforms import v2 as T2
12
  import cv2
13
  import PIL
 
27
  ]
28
  )
29
 
30
+ cnet, cnet_params = FlaxControlNetModel.from_pretrained("./models/catcon-controlnet-wd", dtype=jnp.bfloat16, from_flax=True)
31
+ pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
32
  "./models/wd-1-5-b2",
33
  controlnet=cnet,
34
  dtype=jnp.bfloat16,
 
38
  #pipe.enable_model_cpu_offload()
39
  #pipe.enable_xformers_memory_efficient_attention()
40
 
41
def get_random(seed):
    """Build a JAX PRNG key from an integer seed.

    Args:
        seed: Integer seed passed to ``jax.random.PRNGKey``.

    Returns:
        The PRNG key created for ``seed``.
    """
    # The key has to be returned: the caller binds the result of this
    # function to ``rng``, and without a ``return`` that would be ``None``.
    return jax.random.PRNGKey(seed)
43
 
44
  # inference function takes prompt, negative prompt and image
45
  def infer(prompt, negative_prompt, image):
46
  # implement your inference function here
47
+ params["controlnet"] = cnet_params
48
+ num_samples = 1
49
+
50
  inp = Image.fromarray(image)
51
 
52
  cond_input = conditioning_image_transforms(inp)
53
  cond_input = T.ToPILImage()(cond_input)
54
 
55
+ cond_img_in = pipe.prepare_image_inputs([cond_input] * num_samples)
56
+
57
+ prompt_in = pipe.prepare_text_inputs([prompt] * num_samples)
58
+ prompt_in = shard(prompt_in)
59
+
60
+ n_prompt_in = pipe.prepare_text_inputs([negative_prompt] * num_samples)
61
+ n_prompt_in = shard(n_prompt_in)
62
+
63
+ rng = get_random(0)
64
+ rng.random.split(rng, jax.device_count())
65
+
66
+ p_params = replicate(params)
67
 
68
  output = pipe(
69
+ prompt_ids=prompts_in,
70
+ image=cond_img_in,
71
+ prng_seed=rng,
72
+ neg_prompt_ids=n_prompt_in,
73
  num_inference_steps=20,
74
  jit=True
75
  )