karimbenharrak committed (verified)
Commit 0821367 · 1 Parent(s): 0a1bf9a

Update handler.py

Files changed (1):
  1. handler.py +83 -15
handler.py CHANGED
@@ -1,6 +1,6 @@
 from typing import Dict, List, Any
 import torch
-from diffusers import StableDiffusionXLImg2ImgPipeline, DiffusionPipeline, AutoencoderKL
+from diffusers import StableDiffusionXLImg2ImgPipeline, DiffusionPipeline, AutoencoderKL, DPMSolverMultistepScheduler, DDIMScheduler, StableDiffusionInpaintPipeline, AutoPipelineForInpainting, AutoPipelineForImage2Image, StableDiffusionControlNetInpaintPipeline, ControlNetModel
 from PIL import Image
 import base64
 from io import BytesIO
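Note: the rewritten __call__ below uses VaeImageProcessor, and the new make_inpaint_condition helper uses np, yet neither appears in this import hunk. Since the diff only shows changed hunks, the file presumably imports them elsewhere; a sketch of what it would need (an unverified assumption):

    import numpy as np                                        # needed by make_inpaint_condition
    from diffusers.image_processor import VaeImageProcessor   # needed by the "rasterize" branch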
@@ -24,6 +24,22 @@ class EndpointHandler():
         self.vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0",
                                                  subfolder="vae", use_safetensors=True,
                                                  ).to("cuda")
+
+        self.smooth_pipe.enable_model_cpu_offload()
+        self.enable_xformers_memory_efficient_attention()
+
+
+        self.controlnet = ControlNetModel.from_pretrained(
+            "lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16
+        )
+
+        self.pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
+            "runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, torch_dtype=torch.float16
+        )
+
+        self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
+        self.pipe.enable_model_cpu_offload()
+        self.pipe.enable_xformers_memory_efficient_attention()
 
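Review note: enable_xformers_memory_efficient_attention() is called on the handler itself here, not on a pipeline; unless EndpointHandler defines such a method outside the shown hunks, this raises AttributeError when the handler is constructed. Judging by the self.pipe.* calls a few lines down, the intended receiver was probably the pipeline (an assumption, not confirmed by the diff):

    self.smooth_pipe.enable_xformers_memory_efficient_attention()  # assumed intent, mirroring the self.pipe.* calls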
@@ -32,25 +48,67 @@ class EndpointHandler():
         :param data: A dictionary contains `inputs` and optional `image` field.
         :return: A dictionary with `image` field contains image in base64.
         """
-        encoded_image = data.pop("image", None)
-
-        prompt = data.pop("prompt", "")
-        num_inference_steps = data.pop("num_inference_steps", 50)
 
-        if encoded_image is not None:
-            image = self.decode_base64_image(encoded_image).convert('RGB')
-
-            image_processor = VaeImageProcessor();
-            latents = image_processor.preprocess(image)
-            latents = latents.to(device="cuda")
-
-            with torch.no_grad():
-                latents_dist = self.vae.encode(latents).latent_dist.sample() * self.vae.config.scaling_factor
-
-            self.smooth_pipe.enable_xformers_memory_efficient_attention()
-            out = self.smooth_pipe(prompt, image=latents_dist, num_inference_steps=num_inference_steps).images
-
-            return out
+        method = data.pop("method", "rasterize")
+
+        if(method == "rasterize"):
+            encoded_image = data.pop("image", None)
+
+            prompt = data.pop("prompt", "")
+            num_inference_steps = data.pop("num_inference_steps", 50)
+
+            if encoded_image is not None:
+                image = self.decode_base64_image(encoded_image).convert('RGB')
+
+                image_processor = VaeImageProcessor();
+                latents = image_processor.preprocess(image)
+                latents = latents.to(device="cuda")
+
+                with torch.no_grad():
+                    latents_dist = self.vae.encode(latents).latent_dist.sample() * self.vae.config.scaling_factor
+
+                self.smooth_pipe.enable_xformers_memory_efficient_attention()
+                out = self.smooth_pipe(prompt, image=latents_dist, num_inference_steps=num_inference_steps).images
+
+                return out
+        else:
+            encoded_image = data.pop("image", None)
+            encoded_mask_image = data.pop("mask_image", None)
+
+            prompt = data.pop("prompt", "")
+            negative_prompt = data.pop("negative_prompt", "")
+
+            method = data.pop("method", "slow")
+            strength = data.pop("strength", 0.2)
+            guidance_scale = data.pop("guidance_scale", 8.0)
+            num_inference_steps = data.pop("num_inference_steps", 20)
+
+            # process image
+            if encoded_image is not None and encoded_mask_image is not None:
+                image = self.decode_base64_image(encoded_image).convert("RGB")
+                mask_image = self.decode_base64_image(encoded_mask_image).convert("RGB")
+            else:
+                image = None
+                mask_image = None
+
+            control_image = self.make_inpaint_condition(image, mask_image)
+
+            # generate image
+            image = self.pipe(
+                prompt=prompt,
+                negative_prompt=negative_prompt,
+                num_inference_steps=num_inference_steps,
+                eta=1.0,
+                image=image,
+                mask_image=mask_image,
+                control_image=control_image,
+                guidance_scale=guidance_scale,
+                strength=strength
+            ).images[0]
+
+            return image
 
     # helper to decode input image
     def decode_base64_image(self, image_string):
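For orientation, a minimal sketch of the request payloads the rewritten __call__ accepts, reconstructed from the data.pop(...) calls above; the helper and file names are placeholders, not part of the commit. Two caveats visible in the hunk: the inpainting branch needs both image and mask_image (otherwise make_inpaint_condition receives None and crashes), and both branches return PIL images rather than the base64 string the docstring promises.

    import base64
    from io import BytesIO
    from PIL import Image

    def to_b64(img):
        # hypothetical helper: PNG-encode a PIL image as base64, the inverse of decode_base64_image
        buf = BytesIO()
        img.save(buf, format="PNG")
        return base64.b64encode(buf.getvalue()).decode("utf-8")

    # "rasterize" branch: SDXL img2img over pre-encoded VAE latents
    rasterize_payload = {
        "method": "rasterize",
        "prompt": "a clean line drawing of a house",
        "image": to_b64(Image.open("sketch.png")),      # placeholder file
        "num_inference_steps": 50,
    }

    # any other method value: ControlNet inpainting; image AND mask_image are required
    inpaint_payload = {
        "method": "inpaint",
        "prompt": "a wooden bench in a park",
        "negative_prompt": "blurry, low quality",
        "image": to_b64(Image.open("scene.png")),       # placeholder file
        "mask_image": to_b64(Image.open("mask.png")),   # white pixels = region to repaint
        "strength": 0.2,
        "guidance_scale": 8.0,
        "num_inference_steps": 20,
    }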
@@ -58,3 +116,13 @@ class EndpointHandler():
         buffer = BytesIO(base64_image)
         image = Image.open(buffer)
         return image
+
+    def make_inpaint_condition(self, image, image_mask):
+        image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
+        image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0
+
+        assert image.shape[0:1] == image_mask.shape[0:1], "image and image_mask must have the same image size"
+        image[image_mask > 0.5] = -1.0  # set as masked pixel
+        image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)
+        image = torch.from_numpy(image)
+        return image
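make_inpaint_condition matches the helper from the diffusers ControlNet inpainting example: pixels are scaled to [0, 1] and masked pixels are overwritten with -1.0, which is how control_v11p_sd15_inpaint distinguishes regions to repaint from regions to keep. One subtlety: the assert compares shape[0:1], a one-element slice, so only the heights are checked, not the widths. A self-contained sketch of the output on dummy data (sizes and colors are arbitrary assumptions):

    import numpy as np
    import torch
    from PIL import Image

    image = Image.new("RGB", (64, 64), (200, 50, 50))  # dummy source image
    mask = Image.new("L", (64, 64), 0)                 # 0 = keep
    mask.paste(255, (16, 16, 48, 48))                  # 255 = repaint this square

    img = np.array(image.convert("RGB")).astype(np.float32) / 255.0
    msk = np.array(mask.convert("L")).astype(np.float32) / 255.0
    img[msk > 0.5] = -1.0                              # masked pixels flagged with -1.0
    control = torch.from_numpy(np.expand_dims(img, 0).transpose(0, 3, 1, 2))

    print(control.shape)         # torch.Size([1, 3, 64, 64]) -- NCHW, as the pipeline expects
    print(control.min().item())  # -1.0 inside the masked square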