import torch
import numpy as np
import torchvision.transforms as T
from torchgeo.trainers import SemanticSegmentationTask
import gradio as gr
from PIL import Image
import cv2

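def load_model(checkpoint_path):
    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only
    # machine; the model is moved to the target device by the caller.
    model = SemanticSegmentationTask.load_from_checkpoint(checkpoint_path, map_location="cpu")
    # Switch to inference mode (disables dropout, freezes batch-norm statistics).
    model.eval()
    return model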

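def preprocess_image(inp):
    # Force three channels (e.g. drop the alpha band of an RGBA PNG), resize to
    # 2048x2048 and scale pixel values to [0, 1].
    compose = T.Compose([T.Resize((2048, 2048)), T.ToTensor()])
    # Add a batch dimension: (3, H, W) -> (1, 3, H, W).
    return compose(inp.convert("RGB")).unsqueeze(0)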

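def predict_segmentation(model, inp):
    # Run inference without tracking gradients, then turn the per-class logits
    # into probabilities with a softmax over the channel (class) dimension.
    with torch.no_grad():
        y_hat = torch.nn.Softmax2d()(model(inp))
    # Drop only the batch dimension: (1, C, H, W) -> (C, H, W).
    return y_hat.squeeze(0)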

def overlay_prediction(input_image, prediction_tensor, alpha=0.5, threshold=0.25):
    # Convert the class-0 probability map (values in [0, 1]) to an 8-bit grayscale
    # image and resize it to match the input image size.
    prediction_image = T.ToPILImage()(prediction_tensor[0])
    prediction_image = prediction_image.resize(input_image.size, resample=Image.NEAREST)
    prob_u8 = np.array(prediction_image)  # shape (H, W), dtype uint8

    # Apply the INFERNO colormap; OpenCV returns BGR, so convert to RGB for PIL.
    colored = cv2.applyColorMap(prob_u8, cv2.COLORMAP_INFERNO)
    colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)

    # Threshold on the predicted probability: pixels above the threshold get
    # opacity `alpha`, all others stay fully transparent. Vectorized with NumPy
    # instead of a per-pixel getpixel/putpixel loop.
    mask = prob_u8 / 255.0 > threshold
    alpha_channel = np.where(mask, int(255 * alpha), 0).astype(np.uint8)
    overlay = Image.fromarray(np.dstack([colored, alpha_channel]))  # (H, W, 4) -> RGBA

    combined_image = Image.alpha_composite(input_image.convert("RGBA"), overlay)
    return combined_image.convert("RGB")

# Load the model once at startup rather than on every request, and move it to
# the GPU if one is available.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
MODEL = load_model("./unet_resnet50.ckpt").to(DEVICE)


def predict(inp):
    # Move the input tensor to the same device as the model.
    preprocessed_image = preprocess_image(inp).to(DEVICE)
    segmentation_result = predict_segmentation(MODEL, preprocessed_image)
    # Move the output tensor back to the CPU for post-processing with NumPy/PIL.
    segmentation_result = segmentation_result.cpu()
    return overlay_prediction(inp, segmentation_result)

# gr.inputs.* was deprecated in Gradio 3.x and removed in 4.0; use gr.Image directly.
gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="image",
    examples=["./example1.jpg", "./example2.jpg", "./example3.jpg"],
).launch()