Sebastian Semeniuc commited on
Commit
a1f58f2
·
1 Parent(s): b9fbc24

fix: fix running handler and add test file for catching potential bugs

Browse files
Files changed (2) hide show
  1. handler.py +3 -4
  2. test.py +19 -0
handler.py CHANGED
@@ -100,10 +100,10 @@ class EndpointHandler():
100
 
101
  # Load StableDiffusionControlNetPipeline
102
  self.sdxl_id = "stabilityai/stable-diffusion-xl-base-1.0"
103
- self.stable_diffusion_id = "runwayml/stable-diffusion-v1-5"
104
  self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(self.sdxl_id,
105
  controlnet=self.controlnet,
106
- vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True),
107
  torch_dtype=torch.float16,
108
  use_safetensors=True)
109
  self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
@@ -149,7 +149,7 @@ class EndpointHandler():
149
  control_image = SDXLCONTROLNET_MAPPING[self.control_type]["hinter"](
150
  image, width=1024, height=1024)
151
 
152
- generator = torch.manual_seed(1)
153
 
154
  # run inference pipeline
155
  images = self.pipe(
@@ -158,7 +158,6 @@ class EndpointHandler():
158
  image=control_image,
159
  # width=width,
160
  # height=height,
161
- generator=generator,
162
  num_inference_steps=num_inference_steps,
163
  guidance_scale=guidance_scale,
164
  num_images_per_prompt=num_of_images,
 
100
 
101
  # Load StableDiffusionControlNetPipeline
102
  self.sdxl_id = "stabilityai/stable-diffusion-xl-base-1.0"
103
+ # self.stable_diffusion_id = "runwayml/stable-diffusion-v1-5"
104
  self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(self.sdxl_id,
105
  controlnet=self.controlnet,
106
+ # vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16, use_safetensors=True),
107
  torch_dtype=torch.float16,
108
  use_safetensors=True)
109
  self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
 
149
  control_image = SDXLCONTROLNET_MAPPING[self.control_type]["hinter"](
150
  image, width=1024, height=1024)
151
 
152
+ self.generator = torch.manual_seed(1)
153
 
154
  # run inference pipeline
155
  images = self.pipe(
 
158
  image=control_image,
159
  # width=width,
160
  # height=height,
 
161
  num_inference_steps=num_inference_steps,
162
  guidance_scale=guidance_scale,
163
  num_images_per_prompt=num_of_images,
test.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from handler import EndpointHandler
2
+
3
+ # init handler
4
+ my_handler = EndpointHandler(path=".")
5
+
6
+ # prepare sample payload
7
+ non_holiday_payload = {"inputs": "I am quite excited how this will turn out", "date": "2022-08-08"}
8
+ holiday_payload = {"inputs": "Today is a though day", "date": "2022-07-04"}
9
+
10
+ # test the handler
11
+ non_holiday_pred=my_handler(non_holiday_payload)
12
+ holiday_payload=my_handler(holiday_payload)
13
+
14
+ # show results
15
+ print("non_holiday_pred", non_holiday_pred)
16
+ print("holiday_payload", holiday_payload)
17
+
18
+ # non_holiday_pred [{'label': 'joy', 'score': 0.9985942244529724}]
19
+ # holiday_payload [{'label': 'happy', 'score': 1}]