# danube2024's picture
# Update app.py
# 3a64eb8 verified
import gradio as gr
import torch
import numpy as np
from diffusers import StableDiffusionXLPipeline
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
from PIL import Image, ImageEnhance, ImageOps
# ---------------------------------------------------------------------------
# Model setup (runs once at import time).
# ---------------------------------------------------------------------------

# Hub identifiers used below, gathered in one place.
SDXL_REPO_ID = "stabilityai/stable-diffusion-xl-base-1.0"
LORA_REPO_ID = "KappaNeuro/bas-relief"  # The HF repo with BAS-RELIEF.safetensors
LORA_WEIGHT_FILE = "BAS-RELIEF.safetensors"
DEPTH_REPO_ID = "Intel/dpt-large"

# Target device for both models; change to "cuda" if you have a GPU.
device = "cpu"
# Precision used when loading the SDXL pipeline.
torch_dtype = torch.float32

print("Loading SDXL Base model...")
pipe = StableDiffusionXLPipeline.from_pretrained(
    SDXL_REPO_ID,
    torch_dtype=torch_dtype,
)
pipe = pipe.to(device)

print("Loading bas-relief LoRA weights with PEFT...")
# NOTE(review): the original author flagged `peft_backend` as crucial; confirm
# that the installed diffusers version actually consumes this keyword.
pipe.load_lora_weights(
    LORA_REPO_ID,
    weight_name=LORA_WEIGHT_FILE,
    peft_backend="peft",
)

print("Loading DPT Depth Model...")
# The same checkpoint provides both the pre-processing config and the weights.
feature_extractor = DPTFeatureExtractor.from_pretrained(DEPTH_REPO_ID)
depth_model = DPTForDepthEstimation.from_pretrained(DEPTH_REPO_ID)
depth_model = depth_model.to(device)
def enhance_depth_map(depth_arr: np.ndarray) -> Image.Image:
    """Turn a raw depth array into a contrast-stretched, sharpened 8-bit image.

    Args:
        depth_arr: Numeric depth values (the caller passes the interpolated
            DPT prediction converted to a NumPy array).

    Returns:
        A PIL image built from the array after min-max normalisation,
        auto-contrast and a 2x sharpness boost.
    """
    lowest = depth_arr.min()
    highest = depth_arr.max()

    # The small epsilon keeps the division defined when the array is constant.
    value_range = highest - lowest + 1e-8
    normalized = (depth_arr - lowest) / value_range

    # Scale the normalised values to 0..255 and truncate to unsigned bytes.
    as_bytes = (normalized * 255).astype(np.uint8)

    grayscale = ImageOps.autocontrast(Image.fromarray(as_bytes))
    return ImageEnhance.Sharpness(grayscale).enhance(2.0)
def generate_bas_relief_and_depth(prompt):
    """Generate a bas-relief styled image for *prompt* and estimate its depth.

    The LoRA trigger token ``BAS-RELIEF`` is prepended to the prompt only when
    the prompt does not already contain it. The UI description asks users to
    type the token themselves, and prepending it unconditionally would send
    the trigger to the pipeline twice.

    Args:
        prompt: Free-text description from the UI. ``None`` is treated as an
            empty prompt; surrounding whitespace is ignored.

    Returns:
        A ``(image, depth_map)`` tuple: the first image produced by the SDXL
        pipeline and the enhanced depth map returned by ``enhance_depth_map``.
    """
    trigger = "BAS-RELIEF"

    # Normalise the input so a missing prompt is not rendered as "None".
    user_text = (prompt or "").strip()
    if trigger in user_text:
        # The user already supplied the trigger token; do not duplicate it.
        full_prompt = user_text
    else:
        full_prompt = f"{trigger} {user_text}".strip()

    print("Generating image with LoRA style...")
    result = pipe(
        prompt=full_prompt,
        num_inference_steps=15,  # reduce if too slow
        guidance_scale=7.5,
        height=512,  # reduce if you still get timeouts
        width=512,
    )
    image = result.images[0]

    print("Running DPT Depth Estimation...")
    inputs = feature_extractor(image, return_tensors="pt").to(device)
    # Inference only: no gradients are needed for the depth model.
    with torch.no_grad():
        predicted_depth = depth_model(**inputs).predicted_depth
        # Resize the prediction back to the generated image's resolution.
        # `image.size` is reversed because interpolate takes (height, width);
        # assumes `image` is a PIL image whose size is (width, height).
        prediction = torch.nn.functional.interpolate(
            predicted_depth.unsqueeze(1),
            size=image.size[::-1],
            mode="bicubic",
            align_corners=False,
        ).squeeze()

    depth_map_pil = enhance_depth_map(prediction.cpu().numpy())
    return image, depth_map_pil
# ---------------------------------------------------------------------------
# Gradio UI: one text prompt in, generated image and depth map out.
# ---------------------------------------------------------------------------

# Heading and explanatory text passed to the interface.
title = "Bas-Relief (SDXL + LoRA) + Depth Map (with PEFT)"
description = (
    "Loads stable-diffusion-xl-base-1.0 on CPU, merges LoRA from 'KappaNeuro/bas-relief'. "
    + "Use 'BAS-RELIEF' token in your prompt to trigger the style, then compute a depth map."
)

# Single text input holding the user's description.
prompt_input = gr.Textbox(
    placeholder="bas-relief with roman soldier, marble relief, intricately carved",
    label="Description",
)

# Two image outputs, in the order the callback returns them.
output_views = [
    gr.Image(label="Bas-Relief Image"),
    gr.Image(label="Depth Map"),
]

iface = gr.Interface(
    fn=generate_bas_relief_and_depth,
    inputs=prompt_input,
    outputs=output_views,
    title=title,
    description=description,
)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()