# File size: 2,428 Bytes
# adaa21e
import torch
import numpy as np
import torchvision.transforms as T
from torchgeo.trainers import SemanticSegmentationTask
import gradio as gr
from PIL import Image
import cv2
def load_model(checkpoint_path):
    """Load a ``SemanticSegmentationTask`` from a checkpoint file.

    Args:
        checkpoint_path: Path passed unchanged to
            ``SemanticSegmentationTask.load_from_checkpoint``.

    Returns:
        The object returned by ``load_from_checkpoint``.
    """
    return SemanticSegmentationTask.load_from_checkpoint(checkpoint_path)
def preprocess_image(inp):
    """Resize an image to 2048x2048 and convert it to a batched tensor.

    Args:
        inp: Image accepted by the torchvision ``Resize`` and ``ToTensor``
            transforms (presumably a PIL image -- confirm against the caller).

    Returns:
        The ``ToTensor`` result with an extra leading dimension of size 1.
    """
    resized = T.Resize((2048, 2048))(inp)
    as_tensor = T.ToTensor()(resized)
    # The model is fed a batch, so add a batch axis in front.
    return as_tensor.unsqueeze(0)
def predict_segmentation(model, inp):
    """Run ``model`` on ``inp`` and return channel-wise softmax scores.

    When ``model`` is a ``torch.nn.Module`` it is switched to evaluation mode
    for the forward pass, so layers such as Dropout and BatchNorm behave
    deterministically. The previous train/eval state of every sub-module is
    restored afterwards. Any other callable is invoked as-is.

    Args:
        model: A ``torch.nn.Module`` or any callable mapping a tensor to a
            3-D or 4-D tensor of raw scores.
        inp: Input tensor handed to ``model``.

    Returns:
        ``Softmax2d`` of the model output with the leading dimension removed
        when that dimension has size 1. Other size-1 dimensions are kept.
    """
    is_module = isinstance(model, torch.nn.Module)
    # Remember each sub-module's mode so the caller's model is left untouched.
    saved_modes = [(m, m.training) for m in model.modules()] if is_module else []
    if is_module:
        model.eval()
    try:
        with torch.no_grad():
            y_hat = torch.nn.Softmax2d()(model(inp))
    finally:
        for module, was_training in saved_modes:
            module.training = was_training
    # Drop only the batch axis; a bare squeeze() would also remove a
    # single-class or single-pixel dimension.
    return y_hat.squeeze(0)
def overlay_prediction(input_image, prediction_tensor, alpha=0.5, threshold=0.25):
    """Blend a colour-mapped prediction channel over the input image.

    The first channel of ``prediction_tensor`` is converted to an 8-bit
    image, resized to the input image's size with nearest-neighbour
    sampling, and coloured with ``cv2.COLORMAP_INFERNO``. Pixels whose score
    (as a fraction of 255) exceeds ``threshold`` are composited over the
    input with opacity ``alpha``; all other pixels show the input unchanged.

    Args:
        input_image: PIL image to draw on (``.size`` and ``.convert`` are used).
        prediction_tensor: Tensor whose index ``0`` is the 2-D score map to
            show (presumably float scores in [0, 1] -- confirm against caller).
        alpha: Overlay opacity; ``int(255 * alpha)`` is clamped to 0..255.
        threshold: Minimum score fraction a pixel needs to be overlaid.

    Returns:
        A new RGB PIL image with the overlay applied.
    """
    # Convert the score map to a PIL image and match the input image's size.
    score_image = T.ToPILImage()(prediction_tensor[0])
    score_image = score_image.resize(input_image.size, resample=Image.NEAREST)
    scores = np.array(score_image)

    # applyColorMap yields BGR channel order; PIL expects RGB.
    colored = cv2.applyColorMap(scores, cv2.COLORMAP_INFERNO)
    colored = cv2.cvtColor(colored, cv2.COLOR_BGR2RGB)

    # Threshold the scores themselves. The alpha channel of a freshly
    # converted RGBA image is always 255, so testing it never filters anything.
    visible = (scores / 255) > threshold
    overlay_alpha = max(0, min(255, int(255 * alpha)))

    # Build the transparent overlay in one vectorised step instead of a
    # per-pixel Python loop.
    rgba = np.zeros(scores.shape[:2] + (4,), dtype=np.uint8)
    rgba[visible, :3] = colored[visible]
    rgba[visible, 3] = overlay_alpha
    overlay = Image.fromarray(rgba)

    combined_image = Image.alpha_composite(input_image.convert("RGBA"), overlay)
    return combined_image.convert("RGB")
def predict(inp):
    """Segment an image and return it with the prediction overlaid.

    Loads the checkpoint ``./unet_resnet50.ckpt``, runs the model on the
    preprocessed image (on CUDA when available, otherwise on CPU), and draws
    the result over the original image.

    Args:
        inp: Input image; it is passed to ``preprocess_image`` and to
            ``overlay_prediction``.

    Returns:
        The image returned by ``overlay_prediction``.
    """
    # NOTE(review): the checkpoint is reloaded on every call; consider
    # caching the loaded model if request latency matters.
    model = load_model("./unet_resnet50.ckpt")

    # Use the GPU when one is available. Model and input must live on the
    # same device, so both are moved.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    # Inference only: disable Dropout and use BatchNorm running statistics.
    model.eval()

    preprocessed_image = preprocess_image(inp).to(device)
    segmentation_result = predict_segmentation(model, preprocessed_image)

    # Move the output tensor back to the CPU for post-processing.
    segmentation_result = segmentation_result.cpu()
    return overlay_prediction(inp, segmentation_result)
# Build the web UI around predict() and start serving it.
# NOTE(review): the opening of this call was missing from the file; fn, inputs
# and outputs are reconstructed from how predict() uses its argument and
# return value (PIL images) -- confirm against the deployed app.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    examples=["./example1.jpg", "./example2.jpg", "./example3.jpg"],
)
demo.launch()