Sebastian Semeniuc
commited on
Commit
·
a1f58f2
1
Parent(s):
b9fbc24
fix: fix running handler and add test file for catching potential bugs
Browse files- handler.py +3 -4
- test.py +19 -0
handler.py
CHANGED
@@ -100,10 +100,10 @@ class EndpointHandler():
|
|
100 |
|
101 |
# Load StableDiffusionControlNetPipeline
|
102 |
self.sdxl_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
103 |
-
self.stable_diffusion_id = "runwayml/stable-diffusion-v1-5"
|
104 |
self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(self.sdxl_id,
|
105 |
controlnet=self.controlnet,
|
106 |
-
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True),
|
107 |
torch_dtype=torch.float16,
|
108 |
use_safetensors=True)
|
109 |
self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
|
@@ -149,7 +149,7 @@ class EndpointHandler():
|
|
149 |
control_image = SDXLCONTROLNET_MAPPING[self.control_type]["hinter"](
|
150 |
image, width=1024, height=1024)
|
151 |
|
152 |
-
generator = torch.manual_seed(1)
|
153 |
|
154 |
# run inference pipeline
|
155 |
images = self.pipe(
|
@@ -158,7 +158,6 @@ class EndpointHandler():
|
|
158 |
image=control_image,
|
159 |
# width=width,
|
160 |
# height=height,
|
161 |
-
generator=generator,
|
162 |
num_inference_steps=num_inference_steps,
|
163 |
guidance_scale=guidance_scale,
|
164 |
num_images_per_prompt=num_of_images,
|
|
|
100 |
|
101 |
# Load StableDiffusionControlNetPipeline
|
102 |
self.sdxl_id = "stabilityai/stable-diffusion-xl-base-1.0"
|
103 |
+
# self.stable_diffusion_id = "runwayml/stable-diffusion-v1-5"
|
104 |
self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(self.sdxl_id,
|
105 |
controlnet=self.controlnet,
|
106 |
+
# vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True),
|
107 |
torch_dtype=torch.float16,
|
108 |
use_safetensors=True)
|
109 |
self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
|
|
|
149 |
control_image = SDXLCONTROLNET_MAPPING[self.control_type]["hinter"](
|
150 |
image, width=1024, height=1024)
|
151 |
|
152 |
+
self.generator = torch.manual_seed(1)
|
153 |
|
154 |
# run inference pipeline
|
155 |
images = self.pipe(
|
|
|
158 |
image=control_image,
|
159 |
# width=width,
|
160 |
# height=height,
|
|
|
161 |
num_inference_steps=num_inference_steps,
|
162 |
guidance_scale=guidance_scale,
|
163 |
num_images_per_prompt=num_of_images,
|
test.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from handler import EndpointHandler
|
2 |
+
|
3 |
+
# init handler
|
4 |
+
my_handler = EndpointHandler(path=".")
|
5 |
+
|
6 |
+
# prepare sample payload
|
7 |
+
non_holiday_payload = {"inputs": "I am quite excited how this will turn out", "date": "2022-08-08"}
|
8 |
+
holiday_payload = {"inputs": "Today is a though day", "date": "2022-07-04"}
|
9 |
+
|
10 |
+
# test the handler
|
11 |
+
non_holiday_pred=my_handler(non_holiday_payload)
|
12 |
+
holiday_payload=my_handler(holiday_payload)
|
13 |
+
|
14 |
+
# show results
|
15 |
+
print("non_holiday_pred", non_holiday_pred)
|
16 |
+
print("holiday_payload", holiday_payload)
|
17 |
+
|
18 |
+
# non_holiday_pred [{'label': 'joy', 'score': 0.9985942244529724}]
|
19 |
+
# holiday_payload [{'label': 'happy', 'score': 1}]
|