Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import gradio as gr
|
2 |
+
import torch
|
3 |
+
from torchvision import transforms
|
4 |
+
from PIL import Image
|
5 |
+
import numpy as np
|
6 |
+
from unet_model import UNet
|
7 |
+
from huggingface_hub import hf_hub_download
|
8 |
+
|
9 |
+
# Load trained model weights from Hugging Face Hub
|
10 |
+
weights_path = hf_hub_download(
|
11 |
+
repo_id="Vizuara/unet-crack-segmentation", # ensure this matches your repo
|
12 |
+
filename="unet_weights_v2.pth" # make sure this file exists in repo
|
13 |
+
)
|
14 |
+
|
15 |
+
# Initialize and load model
|
16 |
+
model = UNet()
|
17 |
+
model.load_state_dict(torch.load(weights_path, map_location="cpu"))
|
18 |
+
model.eval()
|
19 |
+
|
20 |
+
# Preprocessing: same as training
|
21 |
+
IMG_HEIGHT, IMG_WIDTH = 128, 128
|
22 |
+
transform = transforms.Compose([
|
23 |
+
transforms.Resize((IMG_HEIGHT, IMG_WIDTH)),
|
24 |
+
transforms.ToTensor()
|
25 |
+
])
|
26 |
+
|
27 |
+
def predict(image):
|
28 |
+
orig_w, orig_h = image.size # original size of uploaded image
|
29 |
+
img = transform(image).unsqueeze(0) # (1,3,128,128)
|
30 |
+
with torch.no_grad():
|
31 |
+
pred = model(img)
|
32 |
+
|
33 |
+
mask = pred.squeeze(0).squeeze(0).cpu().numpy()
|
34 |
+
mask = (mask * 255).astype(np.uint8) # grayscale mask
|
35 |
+
|
36 |
+
# Resize back to original size
|
37 |
+
mask_img = Image.fromarray(mask).resize((orig_w, orig_h), Image.NEAREST)
|
38 |
+
return mask_img
|
39 |
+
|
40 |
+
|
41 |
+
# Gradio interface
|
42 |
+
demo = gr.Interface(
|
43 |
+
fn=predict,
|
44 |
+
inputs=gr.Image(type="pil"),
|
45 |
+
outputs=gr.Image(type="pil"),
|
46 |
+
title="UNet Crack Segmentation",
|
47 |
+
description="Upload a concrete surface image to get predicted crack mask"
|
48 |
+
)
|
49 |
+
|
50 |
+
if __name__ == "__main__":
|
51 |
+
demo.launch()
|