|
import gradio as gr
import numpy as np
import torch
from diffusers import StableDiffusionXLPipeline
from PIL import Image, ImageEnhance, ImageOps
from transformers import DPTFeatureExtractor, DPTForDepthEstimation, DPTImageProcessor
|
|
|
# Runtime configuration: every model below is placed on the CPU and the
# diffusion pipeline is loaded in float32.
device = "cpu"
torch_dtype = torch.float32

# NOTE: all three loads below run at import time, so importing this module
# triggers the (potentially slow) model loading as a side effect.

# Text-to-image backbone.
print("Loading SDXL Base model...")
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch_dtype,
).to(device)

# Bas-relief style adapter, attached to the pipeline loaded above.
# NOTE(review): the former `peft_backend="peft"` keyword is not among the
# documented `load_lora_weights` / `lora_state_dict` options, so it is no
# longer passed.
print("Loading bas-relief LoRA weights with PEFT...")
pipe.load_lora_weights(
    "KappaNeuro/bas-relief",
    weight_name="BAS-RELIEF.safetensors",
)

# Monocular depth estimator plus its image pre-processor.
# `DPTImageProcessor` replaces the deprecated `DPTFeatureExtractor` alias; the
# variable keeps its historical name because the rest of the file uses it.
print("Loading DPT Depth Model...")
feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-large")
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device)
|
|
|
def enhance_depth_map(depth_arr: np.ndarray) -> Image.Image:
    """Convert a raw depth array into a contrast-stretched, sharpened image.

    The values are min-max normalised, scaled to 0-255 and cast to ``uint8``
    before being wrapped in a PIL image. That image is then auto-contrasted
    and sharpened with a factor of 2.0.

    Args:
        depth_arr: Array of depth values (presumably 2-D, height x width --
            confirm against the caller).

    Returns:
        The enhanced depth map as a PIL image.
    """
    lowest = depth_arr.min()
    highest = depth_arr.max()

    # The tiny epsilon keeps the division defined when the array is constant.
    span = highest - lowest + 1e-8
    normalised = (depth_arr - lowest) / span
    as_bytes = (normalised * 255).astype(np.uint8)

    contrasted = ImageOps.autocontrast(Image.fromarray(as_bytes))
    return ImageEnhance.Sharpness(contrasted).enhance(2.0)
|
|
|
def generate_bas_relief_and_depth(prompt):
    """Generate a bas-relief styled image and an enhanced depth map of it.

    The prompt is sent through the module-level SDXL pipeline (``pipe``) with
    the ``BAS-RELIEF`` trigger token, and the resulting image is fed to the
    module-level DPT model (``depth_model``) to estimate depth. The depth
    prediction is resized to the image resolution and post-processed by
    ``enhance_depth_map``.

    Args:
        prompt: Free-text description of the scene. ``None`` and surrounding
            whitespace are tolerated. If the text already contains the
            trigger token (matched case-insensitively) it is used as is;
            otherwise the token is prepended.

    Returns:
        A ``(image, depth_map)`` tuple: the generated image and the enhanced
        depth map returned by ``enhance_depth_map``.
    """
    trigger = "BAS-RELIEF"

    # Normalise the user text so a missing prompt does not become the literal
    # string "None" and stray whitespace does not end up in the prompt.
    cleaned = (prompt or "").strip()

    # The UI tells users to type the trigger token themselves, so only add it
    # when it is absent; otherwise the token would appear twice.
    if trigger.lower() in cleaned.lower():
        full_prompt = cleaned
    else:
        full_prompt = f"{trigger} {cleaned}".strip()

    print("Generating image with LoRA style...")
    result = pipe(
        prompt=full_prompt,
        num_inference_steps=15,
        guidance_scale=7.5,
        height=512,
        width=512,
    )
    image = result.images[0]

    print("Running DPT Depth Estimation...")
    inputs = feature_extractor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = depth_model(**inputs)
        predicted_depth = outputs.predicted_depth

    # Resize the prediction to the generated image's resolution. `image.size`
    # is reversed because `interpolate` expects the spatial size in the
    # opposite order (height first, assuming a PIL image -- confirm).
    prediction = torch.nn.functional.interpolate(
        predicted_depth.unsqueeze(1),
        size=image.size[::-1],
        mode="bicubic",
        align_corners=False,
    ).squeeze()

    depth_map_pil = enhance_depth_map(prediction.cpu().numpy())

    return image, depth_map_pil
|
|
|
# Static texts passed to the Gradio interface below.
title = "Bas-Relief (SDXL + LoRA) + Depth Map (with PEFT)"
description = (
    "Loads stable-diffusion-xl-base-1.0 on CPU, merges LoRA from 'KappaNeuro/bas-relief'. "
    "Use 'BAS-RELIEF' token in your prompt to trigger the style, then compute a depth map."
)

# Input widget: a single free-text prompt.
_prompt_box = gr.Textbox(
    label="Description",
    placeholder="bas-relief with roman soldier, marble relief, intricately carved",
)

# Output widgets, in the order the callback returns its values.
_result_views = [
    gr.Image(label="Bas-Relief Image"),
    gr.Image(label="Depth Map"),
]

# Wire the generation callback to the widgets.
iface = gr.Interface(
    fn=generate_bas_relief_and_depth,
    inputs=_prompt_box,
    outputs=_result_views,
    title=title,
    description=description,
)

# Start the web UI only when the file is executed as a script.
if __name__ == "__main__":
    iface.launch()
|
|