Spaces:
Runtime error
Runtime error
bram-w
commited on
Commit
·
93d448d
1
Parent(s):
47a05f1
safety check
Browse files- edict_functions.py +29 -1
edict_functions.py
CHANGED
|
@@ -17,6 +17,8 @@ import os
|
|
| 17 |
from torchvision import datasets
|
| 18 |
import pickle
|
| 19 |
|
|
|
|
|
|
|
| 20 |
# StableDiffusion P2P implementation originally from https://github.com/bloc97/CrossAttentionControl
|
| 21 |
use_half_prec = True
|
| 22 |
if use_half_prec:
|
|
@@ -66,7 +68,30 @@ else:
|
|
| 66 |
clip.double().to(device)
|
| 67 |
print("Loaded all models")
|
| 68 |
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
|
| 72 |
def EDICT_editing(im_path,
|
|
@@ -597,6 +622,9 @@ def baseline_stablediffusion(prompt="",
|
|
| 597 |
|
| 598 |
image = (image / 2 + 0.5).clamp(0, 1)
|
| 599 |
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
|
|
|
|
|
|
|
|
|
| 600 |
image = (image[0] * 255).round().astype("uint8")
|
| 601 |
return Image.fromarray(image)
|
| 602 |
####################################
|
|
|
|
| 17 |
from torchvision import datasets
|
| 18 |
import pickle
|
| 19 |
|
| 20 |
+
|
| 21 |
+
|
| 22 |
# StableDiffusion P2P implementation originally from https://github.com/bloc97/CrossAttentionControl
|
| 23 |
use_half_prec = True
|
| 24 |
if use_half_prec:
|
|
|
|
| 68 |
clip.double().to(device)
|
| 69 |
print("Loaded all models")
|
| 70 |
|
| 71 |
+
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
| 72 |
+
from transformers import AutoFeatureExtractor
|
| 73 |
+
# load safety model
|
| 74 |
+
safety_model_id = "CompVis/stable-diffusion-safety-checker"
|
| 75 |
+
safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
|
| 76 |
+
safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)
|
| 77 |
+
def load_replacement(x):
|
| 78 |
+
try:
|
| 79 |
+
hwc = x.shape
|
| 80 |
+
y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0]))
|
| 81 |
+
y = (np.array(y)/255.0).astype(x.dtype)
|
| 82 |
+
assert y.shape == x.shape
|
| 83 |
+
return y
|
| 84 |
+
except Exception:
|
| 85 |
+
return x
|
| 86 |
+
def check_safety(x_image):
|
| 87 |
+
safety_checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
|
| 88 |
+
x_checked_image, has_nsfw_concept = safety_checker(images=x_image, clip_input=safety_checker_input.pixel_values)
|
| 89 |
+
assert x_checked_image.shape[0] == len(has_nsfw_concept)
|
| 90 |
+
for i in range(len(has_nsfw_concept)):
|
| 91 |
+
if has_nsfw_concept[i]:
|
| 92 |
+
# x_checked_image[i] = load_replacement(x_checked_image[i])
|
| 93 |
+
x_checked_image[i] *= 0 # load_replacement(x_checked_image[i])
|
| 94 |
+
return x_checked_image, has_nsfw_concept
|
| 95 |
|
| 96 |
|
| 97 |
def EDICT_editing(im_path,
|
|
|
|
| 622 |
|
| 623 |
image = (image / 2 + 0.5).clamp(0, 1)
|
| 624 |
image = image.cpu().permute(0, 2, 3, 1).numpy()
|
| 625 |
+
|
| 626 |
+
image, _ = check_safety(image)
|
| 627 |
+
|
| 628 |
image = (image[0] * 255).round().astype("uint8")
|
| 629 |
return Image.fromarray(image)
|
| 630 |
####################################
|