shikunl commited on
Commit
5d9ca62
Β·
1 Parent(s): 5571d3e

Fix no obj error

Browse files
Files changed (1) hide show
  1. label_prettify.py +20 -15
label_prettify.py CHANGED
@@ -34,22 +34,27 @@ def obj_detection_prettify(rgb_path, path_name):
34
 
35
  plt.imshow(rgb)
36
 
37
- num_objs = np.unique(obj_labels)[:-1].max()
38
- plt.imshow(obj_labels, cmap='terrain', vmax=num_objs + 1 / 255., alpha=0.8)
39
- cmap = matplotlib.colormaps.get_cmap('terrain')
40
- for i in np.unique(obj_labels)[:-1]:
41
- obj_idx_all = np.where(obj_labels == i)
42
- x, y = obj_idx_all[1].mean(), obj_idx_all[0].mean()
43
- obj_name = obj_label_map[obj_labels_dict[str(int(i * 255))]]
44
- obj_name = obj_name.split(',')[0]
45
- if islight([c*255 for c in cmap(i / num_objs)[:3]]):
46
- plt.text(x, y, obj_name, c='black', horizontalalignment='center', verticalalignment='center', clip_on=True)
47
- else:
48
- plt.text(x, y, obj_name, c='white', horizontalalignment='center', verticalalignment='center', clip_on=True)
 
 
 
 
 
49
 
50
- plt.axis('off')
51
- plt.savefig(path_name, bbox_inches='tight', transparent=True, pad_inches=0)
52
- plt.close()
53
 
54
 
55
  def seg_prettify(rgb_path, file_name):
 
34
 
35
  plt.imshow(rgb)
36
 
37
+ if len(np.unique(obj_labels)) == 1:
38
+ plt.axis('off')
39
+ plt.savefig(path_name, bbox_inches='tight', transparent=True, pad_inches=0)
40
+ plt.close()
41
+ else:
42
+ num_objs = np.unique(obj_labels)[:-1].max()
43
+ plt.imshow(obj_labels, cmap='terrain', vmax=num_objs + 1 / 255., alpha=0.8)
44
+ cmap = matplotlib.colormaps.get_cmap('terrain')
45
+ for i in np.unique(obj_labels)[:-1]:
46
+ obj_idx_all = np.where(obj_labels == i)
47
+ x, y = obj_idx_all[1].mean(), obj_idx_all[0].mean()
48
+ obj_name = obj_label_map[obj_labels_dict[str(int(i * 255))]]
49
+ obj_name = obj_name.split(',')[0]
50
+ if islight([c*255 for c in cmap(i / num_objs)[:3]]):
51
+ plt.text(x, y, obj_name, c='black', horizontalalignment='center', verticalalignment='center', clip_on=True)
52
+ else:
53
+ plt.text(x, y, obj_name, c='white', horizontalalignment='center', verticalalignment='center', clip_on=True)
54
 
55
+ plt.axis('off')
56
+ plt.savefig(path_name, bbox_inches='tight', transparent=True, pad_inches=0)
57
+ plt.close()
58
 
59
 
60
  def seg_prettify(rgb_path, file_name):