Update handler.py
Browse files- handler.py +83 -15
handler.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
from typing import Dict, List, Any
|
2 |
import torch
|
3 |
-
from diffusers import StableDiffusionXLImg2ImgPipeline, DiffusionPipeline, AutoencoderKL
|
4 |
from PIL import Image
|
5 |
import base64
|
6 |
from io import BytesIO
|
@@ -24,6 +24,22 @@ class EndpointHandler():
|
|
24 |
self.vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0",
|
25 |
subfolder="vae", use_safetensors=True,
|
26 |
).to("cuda")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
27 |
|
28 |
|
29 |
|
@@ -32,25 +48,67 @@ class EndpointHandler():
|
|
32 |
:param data: A dictionary contains `inputs` and optional `image` field.
|
33 |
:return: A dictionary with `image` field contains image in base64.
|
34 |
"""
|
35 |
-
encoded_image = data.pop("image", None)
|
36 |
-
|
37 |
-
prompt = data.pop("prompt", "")
|
38 |
-
num_inference_steps = data.pop("num_inference_steps", 50)
|
39 |
|
40 |
-
|
41 |
-
image = self.decode_base64_image(encoded_image).convert('RGB')
|
42 |
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
|
47 |
-
|
48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
-
|
51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
52 |
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
# helper to decode input image
|
56 |
def decode_base64_image(self, image_string):
|
@@ -58,3 +116,13 @@ class EndpointHandler():
|
|
58 |
buffer = BytesIO(base64_image)
|
59 |
image = Image.open(buffer)
|
60 |
return image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Dict, List, Any
import base64
from io import BytesIO

import numpy as np
import torch
from PIL import Image
from diffusers import (
    StableDiffusionXLImg2ImgPipeline,
    DiffusionPipeline,
    AutoencoderKL,
    DPMSolverMultistepScheduler,
    DDIMScheduler,
    StableDiffusionInpaintPipeline,
    AutoPipelineForInpainting,
    AutoPipelineForImage2Image,
    StableDiffusionControlNetInpaintPipeline,
    ControlNetModel,
)
from diffusers.image_processor import VaeImageProcessor
|
|
|
24 |
self.vae = AutoencoderKL.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0",
|
25 |
subfolder="vae", use_safetensors=True,
|
26 |
).to("cuda")
|
27 |
+
|
28 |
+
self.smooth_pipe.enable_model_cpu_offload()
|
29 |
+
self.enable_xformers_memory_efficient_attention()
|
30 |
+
|
31 |
+
|
32 |
+
self.controlnet = ControlNetModel.from_pretrained(
|
33 |
+
"lllyasviel/control_v11p_sd15_inpaint", torch_dtype=torch.float16
|
34 |
+
)
|
35 |
+
|
36 |
+
self.pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
|
37 |
+
"runwayml/stable-diffusion-v1-5", controlnet=self.controlnet, torch_dtype=torch.float16
|
38 |
+
)
|
39 |
+
|
40 |
+
self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)
|
41 |
+
self.pipe.enable_model_cpu_offload()
|
42 |
+
self.pipe.enable_xformers_memory_efficient_attention()
|
43 |
|
44 |
|
45 |
|
|
|
def __call__(self, data: Dict[str, Any]) -> Any:
    """Dispatch an image-generation request to one of the two pipelines.

    :param data: A dictionary with an optional ``method`` field selecting the
        pipeline ("rasterize" for img2img smoothing — the default — anything
        else for ControlNet inpainting) plus the fields that pipeline reads
        (``image``, ``prompt``, ``mask_image``, ``num_inference_steps``, ...).
    :return: For "rasterize", the list of generated PIL images; otherwise the
        single generated PIL image. (Return shapes kept as in the original.)
    :raises ValueError: If the required base64 image (and mask) inputs are
        missing.
    """
    method = data.pop("method", "rasterize")
    # NOTE: the original popped "method" a second time inside the inpaint
    # branch; that pop was dead code (the key was already removed here).
    if method == "rasterize":
        return self._rasterize(data)
    return self._inpaint(data)

def _rasterize(self, data: Dict[str, Any]) -> List[Any]:
    # Smooth/refine a single input image through the img2img pipeline.
    encoded_image = data.pop("image", None)
    prompt = data.pop("prompt", "")
    num_inference_steps = data.pop("num_inference_steps", 50)

    # Fail fast with a clear message instead of the NameError the original
    # raised when "image" was omitted (it used `image` without defining it).
    if encoded_image is None:
        raise ValueError("'image' is required for method 'rasterize'")
    image = self.decode_base64_image(encoded_image).convert("RGB")

    # Normalize the image and encode it to VAE latents on the GPU.
    image_processor = VaeImageProcessor()
    latents = image_processor.preprocess(image).to(device="cuda")
    with torch.no_grad():
        latents_dist = (
            self.vae.encode(latents).latent_dist.sample()
            * self.vae.config.scaling_factor
        )

    self.smooth_pipe.enable_xformers_memory_efficient_attention()
    return self.smooth_pipe(
        prompt, image=latents_dist, num_inference_steps=num_inference_steps
    ).images

def _inpaint(self, data: Dict[str, Any]) -> Any:
    # ControlNet-guided inpainting of the masked region of the input image.
    encoded_image = data.pop("image", None)
    encoded_mask_image = data.pop("mask_image", None)
    prompt = data.pop("prompt", "")
    negative_prompt = data.pop("negative_prompt", "")
    strength = data.pop("strength", 0.2)
    guidance_scale = data.pop("guidance_scale", 8.0)
    num_inference_steps = data.pop("num_inference_steps", 20)

    # The original set image/mask to None on missing input and then crashed
    # inside make_inpaint_condition (None.convert); validate up front.
    if encoded_image is None or encoded_mask_image is None:
        raise ValueError("'image' and 'mask_image' are required for inpainting")
    image = self.decode_base64_image(encoded_image).convert("RGB")
    mask_image = self.decode_base64_image(encoded_mask_image).convert("RGB")

    control_image = self.make_inpaint_condition(image, mask_image)

    # eta=1.0 follows the diffusers ControlNet-inpaint example for the
    # DDIM scheduler configured in __init__.
    return self.pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=num_inference_steps,
        eta=1.0,
        image=image,
        mask_image=mask_image,
        control_image=control_image,
        guidance_scale=guidance_scale,
        strength=strength,
    ).images[0]
# helper to decode input image
def decode_base64_image(self, image_string):
    """Decode a base64-encoded string into a PIL image."""
    raw_bytes = base64.b64decode(image_string)
    return Image.open(BytesIO(raw_bytes))
def make_inpaint_condition(self, image, image_mask):
    """Build the ControlNet inpaint conditioning tensor.

    Pixels under the mask are set to -1.0 so the inpaint ControlNet treats
    them as "to be filled"; all other pixels keep their [0, 1]-normalized
    RGB values.

    :param image: Source image (anything with ``.convert("RGB")`` yielding
        an array-like of shape (H, W, 3), e.g. a PIL image).
    :param image_mask: Mask image (``.convert("L")`` yielding (H, W));
        values above 0.5 after /255 normalization mark masked pixels.
    :return: A float32 torch tensor of shape (1, 3, H, W).
    :raises ValueError: If the image and mask spatial sizes differ.
    """
    image = np.array(image.convert("RGB")).astype(np.float32) / 255.0
    image_mask = np.array(image_mask.convert("L")).astype(np.float32) / 255.0

    # Compare BOTH height and width — the original's shape[0:1] only
    # compared height, letting width mismatches through to an opaque
    # IndexError below. Raise instead of assert so the validation
    # survives `python -O`.
    if image.shape[:2] != image_mask.shape[:2]:
        raise ValueError("image and image_mask must have the same image size")

    image[image_mask > 0.5] = -1.0  # flag masked pixels for the ControlNet
    image = np.expand_dims(image, 0).transpose(0, 3, 1, 2)  # HWC -> NCHW
    return torch.from_numpy(image)