Spaces:
Runtime error
Runtime error
Update model.py
Browse files
model.py
CHANGED
|
@@ -18,6 +18,64 @@ from diffusers import (
|
|
| 18 |
T2IAdapter,
|
| 19 |
)
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
ADAPTER_NAMES = [
|
| 22 |
"TencentARC/t2i-adapter-canny-sdxl-1.0",
|
| 23 |
"TencentARC/t2i-adapter-sketch-sdxl-1.0",
|
|
@@ -57,7 +115,7 @@ class LineartPreprocessor(Preprocessor):
|
|
| 57 |
return self.model.to(device)
|
| 58 |
|
| 59 |
def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
|
| 60 |
-
return self.model(image, detect_resolution=
|
| 61 |
|
| 62 |
|
| 63 |
class MidasPreprocessor(Preprocessor):
|
|
@@ -273,6 +331,8 @@ class Model:
|
|
| 273 |
if apply_preprocess:
|
| 274 |
image = self.preprocessor(image)
|
| 275 |
|
|
|
|
|
|
|
| 276 |
generator = torch.Generator(device=self.device).manual_seed(seed)
|
| 277 |
out = self.pipe(
|
| 278 |
prompt=prompt,
|
|
|
|
| 18 |
T2IAdapter,
|
| 19 |
)
|
| 20 |
|
| 21 |
+
SD_XL_BASE_RATIOS = {
|
| 22 |
+
"0.5": (704, 1408),
|
| 23 |
+
"0.52": (704, 1344),
|
| 24 |
+
"0.57": (768, 1344),
|
| 25 |
+
"0.6": (768, 1280),
|
| 26 |
+
"0.68": (832, 1216),
|
| 27 |
+
"0.72": (832, 1152),
|
| 28 |
+
"0.78": (896, 1152),
|
| 29 |
+
"0.82": (896, 1088),
|
| 30 |
+
"0.88": (960, 1088),
|
| 31 |
+
"0.94": (960, 1024),
|
| 32 |
+
"1.0": (1024, 1024),
|
| 33 |
+
"1.07": (1024, 960),
|
| 34 |
+
"1.13": (1088, 960),
|
| 35 |
+
"1.21": (1088, 896),
|
| 36 |
+
"1.29": (1152, 896),
|
| 37 |
+
"1.38": (1152, 832),
|
| 38 |
+
"1.46": (1216, 832),
|
| 39 |
+
"1.67": (1280, 768),
|
| 40 |
+
"1.75": (1344, 768),
|
| 41 |
+
"1.91": (1344, 704),
|
| 42 |
+
"2.0": (1408, 704),
|
| 43 |
+
"2.09": (1472, 704),
|
| 44 |
+
"2.4": (1536, 640),
|
| 45 |
+
"2.5": (1600, 640),
|
| 46 |
+
"2.89": (1664, 576),
|
| 47 |
+
"3.0": (1728, 576),
|
| 48 |
+
}
|
| 49 |
+
|
| 50 |
+
def find_closest_aspect_ratio(target_width, target_height):
|
| 51 |
+
target_ratio = target_width / target_height
|
| 52 |
+
closest_ratio = None
|
| 53 |
+
min_difference = float('inf')
|
| 54 |
+
|
| 55 |
+
for ratio_str, (width, height) in SD_XL_BASE_RATIOS.items():
|
| 56 |
+
ratio = width / height
|
| 57 |
+
difference = abs(target_ratio - ratio)
|
| 58 |
+
|
| 59 |
+
if difference < min_difference:
|
| 60 |
+
min_difference = difference
|
| 61 |
+
closest_ratio = ratio_str
|
| 62 |
+
|
| 63 |
+
return closest_ratio
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def resize_to_closest_aspect_ratio(image):
|
| 67 |
+
target_width, target_height = image.size
|
| 68 |
+
closest_ratio = find_closest_aspect_ratio(target_width, target_height)
|
| 69 |
+
|
| 70 |
+
# Get the dimensions from the closest aspect ratio in the dictionary
|
| 71 |
+
new_width, new_height = SD_XL_BASE_RATIOS[closest_ratio]
|
| 72 |
+
|
| 73 |
+
# Resize the image to the new dimensions while preserving the aspect ratio
|
| 74 |
+
resized_image = image.resize((new_width, new_height), Image.ANTIALIAS)
|
| 75 |
+
|
| 76 |
+
return resized_image
|
| 77 |
+
|
| 78 |
+
|
| 79 |
ADAPTER_NAMES = [
|
| 80 |
"TencentARC/t2i-adapter-canny-sdxl-1.0",
|
| 81 |
"TencentARC/t2i-adapter-sketch-sdxl-1.0",
|
|
|
|
| 115 |
return self.model.to(device)
|
| 116 |
|
| 117 |
def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
|
| 118 |
+
return self.model(image, detect_resolution=512, image_resolution=1024)
|
| 119 |
|
| 120 |
|
| 121 |
class MidasPreprocessor(Preprocessor):
|
|
|
|
| 331 |
if apply_preprocess:
|
| 332 |
image = self.preprocessor(image)
|
| 333 |
|
| 334 |
+
image = resize_to_closest_aspect_ratio(image)
|
| 335 |
+
|
| 336 |
generator = torch.Generator(device=self.device).manual_seed(seed)
|
| 337 |
out = self.pipe(
|
| 338 |
prompt=prompt,
|