import gradio as gr
import torch
import numpy as np
from diffusers import StableDiffusionXLPipeline
from transformers import DPTFeatureExtractor, DPTForDepthEstimation
from PIL import Image, ImageEnhance, ImageOps

# Target device and dtype shared by the SDXL pipeline and the DPT model below.
# NOTE(review): float32 is used regardless of device; if switching to "cuda",
# consider whether float16 is wanted for SDXL memory/speed — confirm first.
device = "cpu"  # or "cuda" if you have a GPU
torch_dtype = torch.float32

# Base SDXL text-to-image pipeline, loaded by Hub repo id and moved to `device`.
print("Loading SDXL Base model...")
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch_dtype
).to(device)

# Load the bas-relief style LoRA into `pipe`. The style is triggered by the
# "BAS-RELIEF" token, which generate_bas_relief_and_depth prepends to prompts.
print("Loading bas-relief LoRA weights with PEFT...")
pipe.load_lora_weights(
    "KappaNeuro/bas-relief",      # The HF repo with BAS-RELIEF.safetensors
    weight_name="BAS-RELIEF.safetensors",
    peft_backend="peft"          # NOTE(review): not a documented load_lora_weights
                                 # argument; it may be silently ignored — verify.
)

# DPT (Intel/dpt-large) monocular depth estimator. The feature extractor turns
# PIL images into model tensors; the model itself is moved to `device`.
# NOTE(review): DPTFeatureExtractor is deprecated in recent transformers in
# favour of DPTImageProcessor — consider migrating.
print("Loading DPT Depth Model...")
feature_extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-large")
depth_model = DPTForDepthEstimation.from_pretrained("Intel/dpt-large").to(device)

def enhance_depth_map(depth_arr: np.ndarray, *, sharpness: float = 2.0) -> Image.Image:
    """Turn a raw depth array into a contrast-stretched, sharpened 8-bit image.

    The array is min-max normalised to [0, 255], converted to a PIL image,
    auto-contrasted, and then sharpened.

    Args:
        depth_arr: Array of depth values (any real dtype). Presumably the 2-D
            (H, W) DPT prediction — confirm for other callers.
        sharpness: Factor passed to ``ImageEnhance.Sharpness.enhance``. The
            default of 2.0 matches the previous hard-coded value.

    Returns:
        A PIL image with the same spatial size as ``depth_arr`` (mode "L" for
        a 2-D input).

    Raises:
        ValueError: If ``depth_arr`` is empty (``min``/``max`` are undefined).
    """
    depth = np.asarray(depth_arr, dtype=np.float64)
    d_min, d_max = depth.min(), depth.max()
    # The epsilon keeps a flat (constant) map from dividing by zero; it maps to 0.
    normalized = (depth - d_min) / (d_max - d_min + 1e-8)
    # Round rather than truncate: a plain astype() floors, and because of the
    # epsilon above the maximum would otherwise land on 254 instead of 255.
    depth_u8 = np.clip(np.rint(normalized * 255.0), 0, 255).astype(np.uint8)

    depth_pil = ImageOps.autocontrast(Image.fromarray(depth_u8))
    return ImageEnhance.Sharpness(depth_pil).enhance(sharpness)

def _render_relief(prompt):
    """Generate one image from `prompt` with the LoRA-styled SDXL pipeline."""
    # The LoRA is keyed on the literal "BAS-RELIEF" trigger token.
    styled_prompt = f"BAS-RELIEF {prompt}"
    print("Generating image with LoRA style...")
    generation = pipe(
        prompt=styled_prompt,
        num_inference_steps=15,  # lower for speed on slow hardware
        guidance_scale=7.5,
        height=512,              # lower resolution keeps CPU runs tractable
        width=512,
    )
    return generation.images[0]


def _estimate_depth(image):
    """Run DPT on `image` and return its depth prediction at the image's size."""
    print("Running DPT Depth Estimation...")
    model_inputs = feature_extractor(image, return_tensors="pt").to(device)
    with torch.no_grad():
        raw_depth = depth_model(**model_inputs).predicted_depth

    # PIL reports (width, height); interpolate expects (height, width).
    width, height = image.size
    resized = torch.nn.functional.interpolate(
        raw_depth.unsqueeze(1),
        size=(height, width),
        mode="bicubic",
        align_corners=False,
    ).squeeze()
    return resized.cpu().numpy()


def generate_bas_relief_and_depth(prompt):
    """Create a bas-relief image for `prompt` together with its depth map.

    The "BAS-RELIEF" trigger token is prepended to `prompt` before generation.

    Returns:
        A ``(image, depth_map)`` pair: the generated PIL image and the
        enhanced depth map produced by ``enhance_depth_map``.
    """
    relief = _render_relief(prompt)
    depth_map = enhance_depth_map(_estimate_depth(relief))
    return relief, depth_map

title = "Bas-Relief (SDXL + LoRA) + Depth Map (with PEFT)"
# User-facing summary. It must match what the code does:
# - the LoRA is loaded via load_lora_weights and is not fused/merged;
# - the "BAS-RELIEF" trigger is prepended by generate_bas_relief_and_depth;
# - the device comes from the module-level `device` setting.
description = (
    f"Loads stable-diffusion-xl-base-1.0 on {str(device).upper()} with the LoRA "
    "adapter from 'KappaNeuro/bas-relief'. The 'BAS-RELIEF' trigger token is "
    "prepended to your prompt automatically; a depth map is then computed "
    "from the generated image."
)

# One text input -> two image outputs (the relief and its depth map).
iface = gr.Interface(
    fn=generate_bas_relief_and_depth,
    inputs=gr.Textbox(
        label="Description",
        # No "bas-relief" prefix here: the trigger token is added automatically.
        placeholder="roman soldier, marble relief, intricately carved"
    ),
    outputs=[gr.Image(label="Bas-Relief Image"), gr.Image(label="Depth Map")],
    title=title,
    description=description
)

if __name__ == "__main__":
    # Serve the Gradio UI only when run as a script, not when imported.
    iface.launch()