File size: 2,778 Bytes
5464cad
 
 
429658f
5464cad
 
 
 
 
 
 
28ac920
 
 
 
5464cad
 
 
 
28ac920
5464cad
 
 
 
 
 
28ac920
5464cad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28ac920
 
 
5464cad
 
 
429658f
5464cad
 
28ac920
5464cad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28ac920
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import torch
import torch.nn as nn
from torchvision import transforms
from safetensors.torch import load_file
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec

from model import U2Net

# Select the compute backend once at import time: prefer CUDA, then Apple's
# MPS backend, and fall back to the CPU when neither reports as available.
device = torch.device(
    'cuda' if torch.cuda.is_available()
    else 'mps' if torch.backends.mps.is_available()
    else 'cpu'
)

def preprocess_image(image_path):
    """Load an image file and turn it into a normalised model-input batch.

    The image is converted to RGB, resized to 512x512 with bilinear
    interpolation, converted to a float tensor and normalised per channel
    with the standard ImageNet mean/std values.

    Args:
        image_path: Path of the image file to load.

    Returns:
        A tensor of shape (1, 3, 512, 512) placed on the module-level
        ``device``.
    """
    preprocess = transforms.Compose([
        transforms.Resize((512, 512), interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
    # Close the source file deterministically instead of relying on Pillow's
    # lazy close; convert() yields an independent in-memory image.
    with Image.open(image_path) as src:
        img = src.convert('RGB')
    # unsqueeze(0) adds the batch dimension the model expects.
    return preprocess(img).unsqueeze(0).to(device)

def run_inference(model, image_path, threshold=None):
    """Run the model on one image and return a uint8 saliency mask.

    The first output of ``model`` is passed through a sigmoid, the first
    batch element is taken, and the map is min-max normalised to [0, 1]
    before being mapped to 0..255.

    Args:
        model: Network whose forward pass returns a sequence; only its first
            element (``d1``) is used.
        image_path: Path of the image to segment (loaded via
            ``preprocess_image``).
        threshold: Optional cut-off applied to the *normalised* map. When
            given, pixels strictly above it become 255 and all others 0.
            When ``None``, the normalised map is returned as a grayscale ramp.

    Returns:
        ``np.ndarray`` of dtype ``uint8`` with values in 0..255. A constant
        prediction (zero dynamic range) yields an all-zero mask.
    """
    input_img = preprocess_image(image_path)
    with torch.no_grad():
        d1, *_ = model(input_img)
        pred = torch.sigmoid(d1)[0]
        # NOTE(review): the model's output layout is not visible here. If it is
        # (N, 1, H, W), the batch slice still carries a singleton channel axis
        # that plt.imshow rejects, so drop it to always return (H, W).
        if pred.dim() == 3 and pred.shape[0] == 1:
            pred = pred[0]
        pred = pred.cpu().numpy()

    lo, hi = pred.min(), pred.max()
    span = hi - lo
    if span > 0:
        pred = (pred - lo) / span
    else:
        # A flat map would divide by zero and produce NaN, whose uint8 cast is
        # undefined; treat it as "nothing salient" instead.
        pred = np.zeros_like(pred)

    if threshold is not None:
        pred = (pred > threshold).astype(np.uint8) * 255
    else:
        pred = (pred * 255).astype(np.uint8)
    return pred

def overlay_segmentation(original_image, binary_mask, alpha=0.5):
    """Blend a mask into the red channel of an image for visual inspection.

    Args:
        original_image: Path of the image file. It is opened with PIL,
            converted to RGB and resized to the mask's height and width.
        binary_mask: Mask whose last two dimensions are (height, width). Its
            values are written into the red channel of the overlay.
        alpha: Weight of the overlay in the blend; the image is weighted by
            ``1 - alpha``.

    Returns:
        ``np.ndarray`` of shape (H, W, 3) and dtype ``uint8``.
    """
    mask = np.asarray(binary_mask)
    height, width = mask.shape[-2:]
    # Size the image from the mask itself rather than a hard-coded 512x512,
    # so the channel assignment below cannot hit a shape mismatch.
    with Image.open(original_image) as src:
        resized = src.convert('RGB').resize((width, height), Image.BILINEAR)
    original_image_np = np.array(resized)

    overlay = np.zeros_like(original_image_np)
    # reshape also accepts a mask with a leading singleton axis, e.g. (1, H, W).
    overlay[:, :, 0] = mask.reshape(height, width)
    overlay_image = (1 - alpha) * original_image_np + alpha * overlay
    return overlay_image.astype(np.uint8)


if __name__ == '__main__':
    # --- run settings ---
    model_path = '../testing/u2net-duts-msra.safetensors'
    filename = input('Filename: ')
    image_path = f'../content_images/{filename}'
    # ---
    model = U2Net().to(device)
    # Wrapped before loading; presumably the checkpoint's keys carry the
    # DataParallel ``module.`` prefix — TODO confirm against the checkpoint.
    model = nn.DataParallel(model)
    model.load_state_dict(load_file(model_path, device=device.type))
    model.eval()

    mask = run_inference(model, image_path)
    mask_with_threshold = run_inference(model, image_path, threshold=0.7)

    # convert('RGB') so grayscale/palette inputs are not rendered through the
    # default colormap (cmap=None below). BILINEAR matches the resampling used
    # for the overlay panel.
    with Image.open(image_path) as src:
        original = src.convert('RGB').resize((512, 512), Image.BILINEAR)

    fig = plt.figure(figsize=(10, 10))
    gs = GridSpec(2, 2, figure=fig, wspace=0, hspace=0)

    # 2x2 layout: [original | soft mask] / [thresholded overlay | thresholded mask]
    images = [
        original,
        mask,
        overlay_segmentation(image_path, mask_with_threshold),
        mask_with_threshold,
    ]

    for i, img in enumerate(images):
        ax = fig.add_subplot(gs[i // 2, i % 2])
        # Odd slots hold single-channel masks; render them in grayscale.
        ax.imshow(img, cmap='gray' if i % 2 != 0 else None)
        ax.axis('off')

    plt.subplots_adjust(left=0, right=1, top=1, bottom=0)
    plt.savefig('../testing/inference-output.jpg', format='jpg', bbox_inches='tight', pad_inches=0)
    plt.close(fig)