Update app.py
app.py CHANGED
@@ -12,8 +12,6 @@ from PIL import Image, ImageEnhance, ImageOps
 device = "cpu"  # or "cuda" if GPU is available
 torch_dtype = torch.float32  # if using CPU or float16 for GPU

-# --- Load Base SDXL Model ---
-# (Large model, be sure you have enough memory or use fewer steps)
 print("Loading SDXL Base model...")
 pipe = StableDiffusionXLPipeline.from_pretrained(
     "stabilityai/stable-diffusion-xl-base-1.0",
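The `device`/`torch_dtype` pair above stays hard-coded for CPU. A common variant (an assumption on my part, not part of this commit) is to detect CUDA at startup instead of editing the file by hand:

import torch

# Hypothetical variant: pick device/dtype at runtime instead of hard-coding.
device = "cuda" if torch.cuda.is_available() else "cpu"
torch_dtype = torch.float16 if device == "cuda" else torch.float32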
@@ -21,86 +19,57 @@ pipe = StableDiffusionXLPipeline.from_pretrained(
 )
 pipe.to(device)

-# --- Load LoRA Weights from KappaNeuro/bas-relief ---
-# The safetensors file is named "BAS-RELIEF.safetensors"
-# This merges the LoRA into the pipeline so you can use it via the "BAS-RELIEF" token
 print("Loading bas-relief LoRA weights...")
+# IMPORTANT: Pass the first argument as a string to the repo or path,
+# and `weight_name` as a kwarg. That matches the actual function signature.
 pipe.load_lora_weights(
-
+    "KappaNeuro/bas-relief",  # repo / path
     weight_name="BAS-RELIEF.safetensors"
 )

-# --- Load Depth Estimation Model ---
-# We'll use Intel's DPT for depth. On CPU, it's also relatively large, so be cautious of performance.
 print("Loading DPT Depth model...")
 feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
 depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device)


-############################################
-# 2. Depth Map Enhancement (PIL-based)
-############################################
 def enhance_depth_map(depth_arr: np.ndarray) -> Image.Image:
     """
-
-    - Auto-contrast to emphasize details
-    - Sharpen edges
+    Normalize depth to [0, 255], auto-contrast, and sharpen.
     """
     d_min, d_max = depth_arr.min(), depth_arr.max()
     depth_stretched = (depth_arr - d_min) / (d_max - d_min + 1e-8)
     depth_stretched = (depth_stretched * 255).astype(np.uint8)

     depth_pil = Image.fromarray(depth_stretched)
-
-    # Auto-contrast
     depth_pil = ImageOps.autocontrast(depth_pil)

-    # Sharpen
     enhancer = ImageEnhance.Sharpness(depth_pil)
     depth_pil = enhancer.enhance(2.0)

     return depth_pil


-############################################
-# 3. Generation + Depth Inference Function
-############################################
 def generate_bas_relief_and_depth(prompt: str):
-    """
-    1) Generate a 'bas-relief' style image using the LoRA from KappaNeuro/bas-relief.
-       - Must include "BAS-RELIEF" token in the prompt for the style to apply.
-    2) Compute a depth map using Intel/DPT-Large.
-    3) Return (image, depth_map).
-    """
-
-    # -- Step A: Merge the user's prompt with "BAS-RELIEF" instance token --
-    # You can experiment with different prompt styles:
-    # e.g. "BAS-RELIEF sculpture of a woman in shibari, marble, octane render..."
+    # We prepend "BAS-RELIEF" to ensure the LoRA style is triggered.
     full_prompt = f"BAS-RELIEF {prompt}"

-    # -- Step B: Generate the image with SDXL + LoRA
-    # Keep resolution modest to avoid timeouts on CPU
     print("Generating bas-relief image...")
     result = pipe(
         prompt=full_prompt,
-        num_inference_steps=15,  # Lower
+        num_inference_steps=15,  # Lower for speed on CPU
         guidance_scale=7.5,
-        height=512,
+        height=512,
         width=512
     )
-
-    # Extract image from pipeline result
     generated_image = result.images[0]

-    # -- Step C: Depth Estimation with DPT
     print("Running depth estimation...")
     inputs = feature_extractor(generated_image, return_tensors="pt").to(device)
-
     with torch.no_grad():
         outputs = depth_model(**inputs)
-    predicted_depth = outputs.predicted_depth
+    predicted_depth = outputs.predicted_depth

-    # Resize to match original image
+    # Resize depth map to match original image
     prediction = torch.nn.functional.interpolate(
         predicted_depth.unsqueeze(1),
         size=generated_image.size[::-1],
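The substantive fix in this hunk is the `pipe.load_lora_weights(...)` call: `diffusers` expects the repo id or local path as the first positional argument, with `weight_name` as a keyword; passing `weight_name` alone raises a `TypeError`. A minimal sketch of the corrected call in isolation (model ids taken from the diff; the surrounding setup is assumed):

import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.float32,  # float32 on CPU
)
# Repo id first, then the specific .safetensors file inside that repo.
pipe.load_lora_weights("KappaNeuro/bas-relief", weight_name="BAS-RELIEF.safetensors")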
@@ -109,31 +78,24 @@ def generate_bas_relief_and_depth(prompt: str):
     ).squeeze(0)

     depth_arr = prediction.cpu().numpy()
-
+    depth_pil = enhance_depth_map(depth_arr)

-    return generated_image,
+    return generated_image, depth_pil


-############################################
-# 4. Gradio Interface
-############################################
-title = "Bas-Relief with SDXL + LoRA + Depth Map"
+title = "Bas-Relief (SDXL + LoRA) + Depth Map"
 description = (
-    "
-    "
-    "Lower resolution or fewer steps if you get timeouts."
+    "Load SDXL base on CPU, apply 'BAS-RELIEF.safetensors' LoRA from KappaNeuro/bas-relief. "
+    "Then run DPT for depth estimation."
 )

 iface = gr.Interface(
     fn=generate_bas_relief_and_depth,
     inputs=gr.Textbox(
         label="Describe your scene/style",
-        placeholder="sculpture of a woman in shibari, marble, intricate details"
+        placeholder="e.g., 'sculpture of a woman in shibari, marble, intricate details'"
     ),
-    outputs=[
-        gr.Image(label="Bas-Relief Image"),
-        gr.Image(label="Depth Map"),
-    ],
+    outputs=[gr.Image(label="Bas-Relief Image"), gr.Image(label="Depth Map")],
     title=title,
     description=description
 )
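One subtlety in the depth-resize step is the `size=generated_image.size[::-1]` argument: PIL's `.size` is a `(width, height)` tuple, while `torch.nn.functional.interpolate` expects `(height, width)`, hence the reversal. A standalone sketch of just that step (shapes and interpolation flags are assumptions; the diff elides the lines carrying them):

import torch
import torch.nn.functional as F

predicted_depth = torch.rand(1, 384, 384)  # (batch, H, W), as DPT returns it
resized = F.interpolate(
    predicted_depth.unsqueeze(1),          # -> (batch, 1, H, W)
    size=(512, 512),                       # (H, W) of the 512x512 generated image
    mode="bicubic",
    align_corners=False,
)
print(resized.shape)                       # torch.Size([1, 1, 512, 512])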
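Note that the diff ends at the `gr.Interface(...)` construction; the launch call (typically `iface.launch()`) sits outside the changed region, so it does not appear here. A quick local smoke test of the updated function, assuming the file is importable as `app`:

from app import generate_bas_relief_and_depth

# One end-to-end run; on CPU this can take several minutes at 512x512 / 15 steps.
image, depth = generate_bas_relief_and_depth("lion head, weathered marble")
image.save("relief.png")
depth.save("depth.png")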