# depthpro-free / app.py
# Uploaded by mminju ("Upload app.py"), commit 62da546 (verified)
# app.py (HF Spaces: SDK=gradio)
import io, base64, numpy as np, torch, gradio as gr
from PIL import Image
from transformers import AutoImageProcessor, DepthProForDepthEstimation
# Choose the compute device once, at import time: CUDA when PyTorch can see a
# GPU, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# DepthPro processor / model singletons. They start as None and are filled in
# on first use by _lazy_init(), so importing this module does not load them.
_proc = _model = None
def _lazy_init():
    """Load the DepthPro image processor and model on first use.

    The loaded objects are stored in the module-level ``_proc`` and
    ``_model`` globals; once both are set, further calls do nothing.
    The model is moved to ``device`` and switched to eval mode.
    """
    # NOTE(review): no lock guards this check-then-load, so two requests that
    # arrive concurrently before the first load finishes may each load the
    # weights. Harmless but wasteful; confirm whether that matters here.
    global _proc, _model
    repo_id = "apple/DepthPro-hf"
    if _proc is None:
        _proc = AutoImageProcessor.from_pretrained(repo_id)
    if _model is None:
        loaded = DepthProForDepthEstimation.from_pretrained(repo_id)
        _model = loaded.to(device).eval()
def _infer(pil_img: Image.Image):
    """Run DepthPro on a single image.

    Args:
        pil_img: Input PIL image; converted to RGB before preprocessing.

    Returns:
        ``(depth, H, W, fov, focal)``:
          * ``depth``: the processor's post-processed ``predicted_depth`` as a
            float32 NumPy array, post-processed with ``target_sizes=[(H, W)]``.
          * ``H``, ``W``: height and width of the input image.
          * ``fov``, ``focal``: the ``field_of_view`` and ``focal_length``
            reported by the post-processor as floats, or 0.0 when missing
            or ``None``.

    Raises:
        gr.Error: if ``pil_img`` is ``None`` (e.g. the image input was left
            empty), so the UI shows a readable message instead of a crash.
    """
    if pil_img is None:
        raise gr.Error("Please provide an input image.")
    _lazy_init()
    H, W = pil_img.height, pil_img.width
    inputs = _proc(images=pil_img.convert("RGB"), return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = _model(**inputs)
    # Ask the post-processor for output at the original image resolution.
    post = _proc.post_process_depth_estimation(outputs, target_sizes=[(H, W)])[0]
    depth = post["predicted_depth"].float().cpu().numpy()

    def _scalar(key):
        # The post-processor may return these keys with a None value, which
        # dict.get's default does not cover; float(None) would raise.
        val = post.get(key)
        return 0.0 if val is None else float(val)

    fov = _scalar("field_of_view")
    focal = _scalar("focal_length")
    return depth, H, W, fov, focal
# (A) API handler: returns the depth map and camera parameters as JSON.
def depth_api(img: Image.Image):
    """Run depth estimation on *img* and return a JSON-serialisable dict.

    Keys:
        height, width: size of the input image (ints).
        focal_px: focal length from the post-processor (float).
        field_of_view: field of view from the post-processor (float).
        depth_flat: the depth map as base64-encoded raw bytes, little-endian
            float32, row-major (C) order. Decode with
            ``np.frombuffer(base64.b64decode(s), "<f4")``. Reshaping to
            ``(height, width)`` assumes the post-processed map matches the
            input size (TODO confirm).
    """
    depth, H, W, fov, focal = _infer(img)
    # Pin the byte order ("<f4") so the wire format does not depend on the
    # server's native endianness; ndarray.tobytes() always emits C order.
    raw = np.asarray(depth, dtype="<f4").tobytes()
    return {
        "height": int(H),
        "width": int(W),
        "focal_px": float(focal),
        "field_of_view": float(fov),
        "depth_flat": base64.b64encode(raw).decode("ascii"),
    }
# (B) Preview UI handler: renders the depth map as a grayscale image.
def preview(img: Image.Image):
    """Run depth estimation on *img* and return an 8-bit grayscale preview.

    Values are contrast-stretched between the 1st and 99th percentiles of
    the finite depth values and mapped to 0..255, so larger values appear
    brighter. If no value is finite, the range [0, 1] is used instead.
    NaN pixels are rendered black.
    """
    depth, *_ = _infer(img)
    finite = depth[np.isfinite(depth)]
    if finite.size:
        lo, hi = np.percentile(finite, [1, 99])
    else:
        lo, hi = 0.0, 1.0
    norm = np.clip((depth - lo) / max(1e-6, hi - lo), 0, 1)
    # np.clip passes NaN through unchanged, and casting NaN to uint8 is
    # undefined, so map NaN to 0 explicitly (+/-inf already clip to 1 / 0).
    norm = np.nan_to_num(norm, nan=0.0)
    return Image.fromarray((norm * 255).astype(np.uint8))
# Build the interactive preview UI: an input image, a depth preview and a
# "Run" button wired to preview().
# NOTE(review): the REST path advertised in the Markdown below depends on the
# Gradio version; confirm it against the Space's "Use via API" page.
with gr.Blocks() as ui:
    gr.Markdown("## DepthPro-hf (CPU, Free Space)\n- REST API: **POST /api/predict/depth** (JSON base64)")
    with gr.Row():
        inp = gr.Image(type="pil", label="Input")
        out = gr.Image(label="Depth (preview)")
    gr.Button("Run").click(fn=preview, inputs=inp, outputs=out)
# API interface: exposes depth_api as the named endpoint "depth", taking a
# PIL image and returning JSON.
# NOTE(review): this file advertises the REST path /api/predict/depth, but the
# concrete URL depends on the Gradio version; verify it before relying on it.
api = gr.Interface(
    depth_api,
    gr.Image(type="pil"),
    gr.JSON(),
    api_name="depth",
)
# Combine the preview UI and the API interface into a single tabbed app.
# NOTE(review): no demo.launch() call is visible in this view; confirm the
# Space runtime launches `demo` or that a launch call follows.
demo = gr.TabbedInterface(
    [ui, api],
    tab_names=["UI", "api"],
)