import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
import os
import torch
import monai.transforms as transforms

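def draw_result(category, image, bboxes, points, logits, gt3D):
    """Save one PNG per slice comparing the image, ground truth, and prediction.

    The volumes are resized to (32, 256, 256); `logits` is thresholded at 0.5
    after a sigmoid. Optional box and point prompts are drawn on the first panel.
    Output goes to ./fig_examples/<category>/<category>_<slice>.png.
    """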
    zoom_out_transform = transforms.Compose([
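        # Note: AddChanneld is deprecated and was removed in MONAI 1.3; on newer
        # versions EnsureChannelFirstd(..., channel_dim="no_channel") replaces it.
        transforms.AddChanneld(keys=["image", "label", "logits"]),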
        transforms.Resized(keys=["image", "label", "logits"], spatial_size=(32,256,256), mode='nearest-exact')
    ])
    print(image.shape, gt3D.shape, logits.shape)
    image = image[0,0]
    post_item = zoom_out_transform({
        'image': image,
        'label': gt3D,
        'logits': logits
    })
    image, gt3D, logits = post_item['image'][0], post_item['label'][0], post_item['logits'][0]
    preds = torch.sigmoid(logits)
    preds = (preds > 0.5).int()

    root_dir = f'./fig_examples/{category}/'

    os.makedirs(root_dir, exist_ok=True)
    if bboxes is not None:
        x1, y1, z1, x2, y2, z2 = bboxes[0].cpu().numpy()
    if points is not None:
        points = (points[0].cpu().numpy(), points[1].cpu().numpy())
        points_ax = points[0]   # [n, 3]
        points_label = points[1] # [n]
        # print(points_ax.shape, points_label.shape)

    for j in range(image.shape[0]):
        img_2d = image[j, :, :].detach().cpu().numpy()
        preds_2d = preds[j, :, :].detach().cpu().numpy()
        label_2d = gt3D[j, :, :].detach().cpu().numpy()
        # if np.sum(label_2d) == 0 and np.sum(preds_2d) == 0:
        #     continue

        # original image
        fig, (ax1, ax2, ax3) = plt.subplots(1, 3)
        ax1.imshow(img_2d, cmap='gray')
        ax1.set_title('Image with prompt') 
        ax1.axis('off')

        # gt
        ax2.imshow(img_2d, cmap='gray')
        show_mask(label_2d, ax2)
        ax2.set_title('Ground truth') 
        ax2.axis('off')

        # preds
        ax3.imshow(img_2d, cmap='gray')
        show_mask(preds_2d, ax3)
        ax3.set_title('Prediction') 
        ax3.axis('off')

        # boxes: the first bbox axis is the slice index, so draw the box
        # only on the slices it spans
        if bboxes is not None:
            if x1 <= j <= x2:
                show_box((z1, y1, z2, y2), ax1)
        # points
        if points is not None:
            for point_idx in range(points_label.shape[0]):
                point = points_ax[point_idx]
                label = points_label[point_idx] # [1]
                if j == point[0]:
                    show_points(point, label, ax1)
        
        fig.subplots_adjust(left=0, right=1, bottom=0, top=1, wspace=0, hspace=0)
        fig.savefig(os.path.join(root_dir, f'{category}_{j}.png'), bbox_inches='tight')
        plt.close(fig)

def show_mask(mask, ax):
    color = np.array([251/255, 252/255, 30/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image, alpha=0.35)

def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='blue', facecolor=(0,0,0,0), lw=2))

def show_points(points_ax, points_label, ax):
    print('draw point')
    color = 'red' if points_label == 0 else 'blue'
    ax.scatter(points_ax[2], points_ax[1], c=color, marker='o', s=200)