Sebastian Semeniuc committed
Commit fae0531 · Parent: 4cc0dca

feat: add sdxl with controlnet

Files changed (3):
  1. handler.py +41 -32
  2. request.json +0 -0
  3. requirements.txt +1 -1
handler.py CHANGED
@@ -1,8 +1,8 @@
-from typing import Dict, List, Any
+from typing import Dict, List, Any
 import base64
 from PIL import Image
 from io import BytesIO
-from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
+from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL, UniPCMultistepScheduler
 import torch
 
 
@@ -15,7 +15,16 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 if device.type != 'cuda':
     raise ValueError("need to run on GPU")
 # set mixed precision dtype
-dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
+dtype = torch.bfloat16 if torch.cuda.get_device_capability()[
+    0] == 8 else torch.float16
+
+# for the moment, support only canny edge
+SDXLCONTROLNET_MAPPING = {
+    "canny_edge": {
+        "model_id": "diffusers/controlnet-canny-sdxl-1.0",
+        "hinter": controlnet_hinter.hint_canny
+    }
+}
 
 # controlnet mapping for controlnet id and control hinter
 CONTROLNET_MAPPING = {
@@ -58,14 +67,16 @@ class EndpointHandler():
     def __init__(self, path=""):
         # define default controlnet id and load controlnet
         self.control_type = "normal"
-        self.controlnet = ControlNetModel.from_pretrained(CONTROLNET_MAPPING[self.control_type]["model_id"],torch_dtype=dtype).to(device)
-
-        # Load StableDiffusionControlNetPipeline
+        self.controlnet = ControlNetModel.from_pretrained(
+            SDXLCONTROLNET_MAPPING[self.control_type]["model_id"], torch_dtype=dtype).to(device)
+
+        # Load StableDiffusionControlNetPipeline
+        self.sdxl_id = "stabilityai/stable-diffusion-xl-base-1.0"
         self.stable_diffusion_id = "runwayml/stable-diffusion-v1-5"
-        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(self.stable_diffusion_id,
-                                                                      controlnet=self.controlnet,
-                                                                      torch_dtype=dtype,
-                                                                      safety_checker=None).to(device)
+        self.pipe = StableDiffusionXLControlNetPipeline.from_pretrained(self.sdxl_id,
+                                                                        controlnet=self.controlnet,
+                                                                        torch_dtype=dtype,
+                                                                        safety_checker=None).to(device)
         # makes inference much faster
         # self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
         # Define Generator with seed
@@ -78,55 +89,53 @@ class EndpointHandler():
         """
         prompt = data.pop("inputs", None)
         image = data.pop("image", None)
-        num_of_images = data.pop("numOfImages", None)
+        num_of_images = data.pop("num_of_images", None)
         controlnet_type = data.pop("controlnet_type", None)
-
+
         # Check if neither prompt nor image is provided
         if prompt is None and image is None:
             return {"error": "Please provide a prompt and base64 encoded image."}
-
+
        if num_of_images is None:
            num_of_images = 1
-
+
         # Check if a new controlnet is provided
         if controlnet_type is not None and controlnet_type != self.control_type:
-            print(f"changing controlnet from {self.control_type} to {controlnet_type} using {CONTROLNET_MAPPING[controlnet_type]['model_id']} model")
+            print(
+                f"changing controlnet from {self.control_type} to {controlnet_type} using {SDXLCONTROLNET_MAPPING[controlnet_type]['model_id']} model")
             self.control_type = controlnet_type
-            self.controlnet = ControlNetModel.from_pretrained(CONTROLNET_MAPPING[self.control_type]["model_id"],
+            self.controlnet = ControlNetModel.from_pretrained(SDXLCONTROLNET_MAPPING[self.control_type]["model_id"],
                                                               torch_dtype=dtype).to(device)
             self.pipe.controlnet = self.controlnet
-
-
+
         # hyperparamters
         num_inference_steps = data.pop("num_inference_steps", 30)
         guidance_scale = data.pop("guidance_scale", 7.5)
         negative_prompt = data.pop("negative_prompt", None)
         height = data.pop("height", None)
         width = data.pop("width", None)
-        controlnet_conditioning_scale = data.pop("controlnet_conditioning_scale", 1.0)
-
+        controlnet_conditioning_scale = data.pop(
+            "controlnet_conditioning_scale", 1.0)
+
         # process image
         image = self.decode_base64_image(image)
-        control_image = CONTROLNET_MAPPING[self.control_type]["hinter"](image)
-
+        control_image = SDXLCONTROLNET_MAPPING[self.control_type]["hinter"](
+            image, width=1024, height=1024)
+
         # run inference pipeline
-        out = self.pipe(
-            prompt=prompt,
+        images = self.pipe(
+            prompt=prompt,
             negative_prompt=negative_prompt,
             image=control_image,
-            num_inference_steps=num_inference_steps,
+            num_inference_steps=num_inference_steps,
             guidance_scale=guidance_scale,
             num_images_per_prompt=num_of_images,
-            height=height,
-            width=width,
             controlnet_conditioning_scale=controlnet_conditioning_scale,
             generator=self.generator
-        )
+        ).images[0]
+
+        return images
 
-
-        # return the list of generated images
-        return out.images
-
     # helper to decode input image
     def decode_base64_image(self, image_string):
         base64_image = base64.b64decode(image_string)
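
Review note: __call__ now returns a single PIL image (out.images[0]) rather than the previous list. One catch: __init__ still defaults self.control_type to "normal", but SDXLCONTROLNET_MAPPING only defines "canny_edge", so the SDXLCONTROLNET_MAPPING["normal"] lookup would raise a KeyError at load time; presumably the default should become "canny_edge". A minimal local smoke test under that assumption (file names and prompt are hypothetical, not part of this commit):

import base64
from handler import EndpointHandler

# load the handler; assumes the default control_type was changed to
# "canny_edge", since the committed default "normal" is not a key in
# SDXLCONTROLNET_MAPPING
handler = EndpointHandler(path=".")

# the handler expects the conditioning image as a base64 string
with open("input.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

image = handler({
    "inputs": "a futuristic city at sunset, highly detailed",
    "image": image_b64,
    "controlnet_type": "canny_edge",
    "num_of_images": 1,
    "num_inference_steps": 30,
    "guidance_scale": 7.5,
})
image.save("output.png")  # a single PIL.Image, not a list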
request.json CHANGED
The diff for this file is too large to render.
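
The payload itself is not visible here, but from the keys handler.py pops, a request presumably has roughly this shape; the values and endpoint URL below are illustrative placeholders, not the actual contents of request.json:

import base64
import requests

# hypothetical endpoint URL; substitute your deployed inference endpoint
API_URL = "https://<your-endpoint>.endpoints.huggingface.cloud"

with open("input.jpg", "rb") as f:
    image_b64 = base64.b64encode(f.read()).decode("utf-8")

payload = {
    "inputs": "a portrait photo, studio lighting",  # prompt
    "image": image_b64,                             # base64-encoded conditioning image
    "controlnet_type": "canny_edge",                # must be a key in SDXLCONTROLNET_MAPPING
    "num_of_images": 1,
    "num_inference_steps": 30,
    "guidance_scale": 7.5,
    "negative_prompt": "blurry, low quality",
    "controlnet_conditioning_scale": 1.0,
}
response = requests.post(API_URL, json=payload)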
 
requirements.txt CHANGED
@@ -1,4 +1,4 @@
-diffusers==0.19.3
+diffusers==0.20.0
 safetensors
 opencv-python
 controlnet_hinter==0.0.5
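
The bump from diffusers 0.19.3 to 0.20.0 lines up with the new import: as far as I recall, StableDiffusionXLControlNetPipeline first shipped in the 0.20.0 release, so the old pin would fail at import time. A quick environment sanity check:

# verify the pinned diffusers release exposes the classes handler.py imports
import diffusers
from diffusers import StableDiffusionXLControlNetPipeline, ControlNetModel, AutoencoderKL

print(diffusers.__version__)  # expect 0.20.0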