Update handler.py
Browse files- handler.py +38 -3
handler.py
CHANGED
@@ -1,6 +1,6 @@
|
|
1 |
from typing import Dict, List, Any
|
2 |
import torch
|
3 |
-
from diffusers import DPMSolverMultistepScheduler, StableDiffusionInpaintPipeline
|
4 |
from PIL import Image
|
5 |
import base64
|
6 |
from io import BytesIO
|
@@ -15,7 +15,7 @@ if device.type != 'cuda':
|
|
15 |
class EndpointHandler():
|
16 |
def __init__(self, path=""):
|
17 |
# load StableDiffusionInpaintPipeline pipeline
|
18 |
-
self.pipe =
|
19 |
"runwayml/stable-diffusion-inpainting",
|
20 |
revision="fp16",
|
21 |
torch_dtype=torch.float16,
|
@@ -25,6 +25,12 @@ class EndpointHandler():
|
|
25 |
# move to device
|
26 |
self.pipe = self.pipe.to(device)
|
27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
|
30 |
"""
|
@@ -43,12 +49,41 @@ class EndpointHandler():
|
|
43 |
else:
|
44 |
image = None
|
45 |
mask_image = None
|
|
|
|
|
46 |
|
47 |
# run inference pipeline
|
48 |
out = self.pipe(prompt=prompt, image=image, mask_image=mask_image)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
|
50 |
# return first generate PIL image
|
51 |
-
return
|
52 |
|
53 |
# helper to decode input image
|
54 |
def decode_base64_image(self, image_string):
|
|
|
1 |
from typing import Dict, List, Any
|
2 |
import torch
|
3 |
+
from diffusers import DPMSolverMultistepScheduler, StableDiffusionInpaintPipeline, AutoPipelineForInpainting, AutoPipelineForImage2Image
|
4 |
from PIL import Image
|
5 |
import base64
|
6 |
from io import BytesIO
|
|
|
15 |
class EndpointHandler():
|
16 |
def __init__(self, path=""):
|
17 |
# load StableDiffusionInpaintPipeline pipeline
|
18 |
+
self.pipe = AutoPipelineForInpainting.from_pretrained(
|
19 |
"runwayml/stable-diffusion-inpainting",
|
20 |
revision="fp16",
|
21 |
torch_dtype=torch.float16,
|
|
|
25 |
# move to device
|
26 |
self.pipe = self.pipe.to(device)
|
27 |
|
28 |
+
self.pipe2 = AutoPipelineForInpainting.from_pretrained("stabilityai/stable-diffusion-xl-refiner-1.0", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
|
29 |
+
self.pipe2.to("cuda")
|
30 |
+
|
31 |
+
self.pipe3 = AutoPipelineForImage2Image.from_pipe(self.pipe2)
|
32 |
+
|
33 |
+
|
34 |
|
35 |
def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
|
36 |
"""
|
|
|
49 |
else:
|
50 |
image = None
|
51 |
mask_image = None
|
52 |
+
|
53 |
+
self.pipe.enable_xformers_memory_efficient_attention()
|
54 |
|
55 |
# run inference pipeline
|
56 |
out = self.pipe(prompt=prompt, image=image, mask_image=mask_image)
|
57 |
+
|
58 |
+
image = out.images[0].resize((1024, 1024))
|
59 |
+
|
60 |
+
self.pipe2.enable_xformers_memory_efficient_attention()
|
61 |
+
|
62 |
+
image = pipe(
|
63 |
+
prompt=prompt,
|
64 |
+
image=image,
|
65 |
+
mask_image=mask_image,
|
66 |
+
guidance_scale=8.0,
|
67 |
+
num_inference_steps=100,
|
68 |
+
strength=0.2,
|
69 |
+
generator=generator,
|
70 |
+
output_type="latent", # let's keep in latent to save some VRAM
|
71 |
+
).images[0]
|
72 |
+
|
73 |
+
pipe = AutoPipelineForImage2Image.from_pipe(pipe)
|
74 |
+
self.pipe3.enable_xformers_memory_efficient_attention()
|
75 |
+
|
76 |
+
image = pipe(
|
77 |
+
prompt=prompt,
|
78 |
+
image=image,
|
79 |
+
guidance_scale=8.0,
|
80 |
+
num_inference_steps=100,
|
81 |
+
strength=0.2,
|
82 |
+
generator=generator,
|
83 |
+
).images[0]
|
84 |
|
85 |
# return first generate PIL image
|
86 |
+
return image
|
87 |
|
88 |
# helper to decode input image
|
89 |
def decode_base64_image(self, image_string):
|