|  | import argparse | 
					
						
						|  | import cv2 | 
					
						
						|  | import os | 
					
						
						|  | import json | 
					
						
						|  | import numpy as np | 
					
						
						|  | from PIL import Image as PILImage | 
					
						
						|  | import joblib | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def mask_nms(masks, bbox_scores, instances_confidence_threshold=0.5, overlap_threshold=0.7): | 
					
						
						|  | """ | 
					
						
						|  | NMS-like procedure used in Panoptic Segmentation | 
					
						
						|  | Remove the overlap areas of different instances in Instance Segmentation | 
					
						
						|  | """ | 
					
						
						|  | panoptic_seg = np.zeros(masks.shape[:2], dtype=np.uint8) | 
					
						
						|  | sorted_inds = list(range(len(bbox_scores))) | 
					
						
						|  | current_segment_id = 0 | 
					
						
						|  | segments_score = [] | 
					
						
						|  |  | 
					
						
						|  | for inst_id in sorted_inds: | 
					
						
						|  | score = bbox_scores[inst_id] | 
					
						
						|  | if score < instances_confidence_threshold: | 
					
						
						|  | break | 
					
						
						|  | mask = masks[:, :, inst_id] | 
					
						
						|  | mask_area = mask.sum() | 
					
						
						|  |  | 
					
						
						|  | if mask_area == 0: | 
					
						
						|  | continue | 
					
						
						|  |  | 
					
						
						|  | intersect = (mask > 0) & (panoptic_seg > 0) | 
					
						
						|  | intersect_area = intersect.sum() | 
					
						
						|  |  | 
					
						
						|  | if intersect_area * 1.0 / mask_area > overlap_threshold: | 
					
						
						|  | continue | 
					
						
						|  |  | 
					
						
						|  | if intersect_area > 0: | 
					
						
						|  | mask = mask & (panoptic_seg == 0) | 
					
						
						|  |  | 
					
						
						|  | current_segment_id += 1 | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | panoptic_seg = np.where(mask == 0, panoptic_seg, current_segment_id) | 
					
						
						|  | segments_score.append(score) | 
					
						
						|  |  | 
					
						
						|  | return panoptic_seg, segments_score | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def extend(si, sj, instance_label, global_label, panoptic_seg_mask, class_map): | 
					
						
						|  | """ | 
					
						
						|  | """ | 
					
						
						|  | directions = [[-1, 0], [0, 1], [1, 0], [0, -1], | 
					
						
						|  | [1, 1], [1, -1], [-1, 1], [-1, -1]] | 
					
						
						|  |  | 
					
						
						|  | inst_class = instance_label[si, sj] | 
					
						
						|  | human_class = panoptic_seg_mask[si, sj] | 
					
						
						|  | global_class = class_map[inst_class] | 
					
						
						|  | queue = [[si, sj]] | 
					
						
						|  |  | 
					
						
						|  | while len(queue) != 0: | 
					
						
						|  | cur = queue[0] | 
					
						
						|  | queue.pop(0) | 
					
						
						|  |  | 
					
						
						|  | for direction in directions: | 
					
						
						|  | ni = cur[0] + direction[0] | 
					
						
						|  | nj = cur[1] + direction[1] | 
					
						
						|  |  | 
					
						
						|  | if ni >= 0 and nj >= 0 and \ | 
					
						
						|  | ni < instance_label.shape[0] and \ | 
					
						
						|  | nj < instance_label.shape[1] and \ | 
					
						
						|  | instance_label[ni, nj] == 0 and \ | 
					
						
						|  | global_label[ni, nj] == global_class: | 
					
						
						|  | instance_label[ni, nj] = inst_class | 
					
						
						|  |  | 
					
						
						|  | panoptic_seg_mask[ni, nj] = human_class | 
					
						
						|  | queue.append([ni, nj]) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def refine(instance_label, panoptic_seg_mask, global_label, class_map): | 
					
						
						|  | """ | 
					
						
						|  | Inputs: | 
					
						
						|  | [ instance_label ] | 
					
						
						|  | np.array() with shape [h, w] | 
					
						
						|  | [ global_label ] with shape [h, w] | 
					
						
						|  | np.array() | 
					
						
						|  | """ | 
					
						
						|  | for i in range(instance_label.shape[0]): | 
					
						
						|  | for j in range(instance_label.shape[1]): | 
					
						
						|  | if instance_label[i, j] != 0: | 
					
						
						|  | extend(i, j, instance_label, global_label, panoptic_seg_mask, class_map) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def get_palette(num_cls): | 
					
						
						|  | """ Returns the color map for visualizing the segmentation mask. | 
					
						
						|  | Inputs: | 
					
						
						|  | =num_cls= | 
					
						
						|  | Number of classes. | 
					
						
						|  | Returns: | 
					
						
						|  | The color map. | 
					
						
						|  | """ | 
					
						
						|  | n = num_cls | 
					
						
						|  | palette = [0] * (n * 3) | 
					
						
						|  | for j in range(0, n): | 
					
						
						|  | lab = j | 
					
						
						|  | palette[j * 3 + 0] = 0 | 
					
						
						|  | palette[j * 3 + 1] = 0 | 
					
						
						|  | palette[j * 3 + 2] = 0 | 
					
						
						|  | i = 0 | 
					
						
						|  | while lab: | 
					
						
						|  | palette[j * 3 + 0] |= (((lab >> 0) & 1) << (7 - i)) | 
					
						
						|  | palette[j * 3 + 1] |= (((lab >> 1) & 1) << (7 - i)) | 
					
						
						|  | palette[j * 3 + 2] |= (((lab >> 2) & 1) << (7 - i)) | 
					
						
						|  | i += 1 | 
					
						
						|  | lab >>= 3 | 
					
						
						|  | return palette | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def patch2img_output(patch_dir, img_name, img_height, img_width, bbox, bbox_type, num_class): | 
					
						
						|  | """transform bbox patch outputs to image output""" | 
					
						
						|  | assert bbox_type == 'gt' or 'msrcnn' | 
					
						
						|  | output = np.zeros((img_height, img_width, num_class), dtype='float') | 
					
						
						|  | output[:, :, 0] = np.inf | 
					
						
						|  | count_predictions = np.zeros((img_height, img_width, num_class), dtype='int32') | 
					
						
						|  | for i in range(len(bbox)): | 
					
						
						|  | file_path = os.path.join(patch_dir, os.path.splitext(img_name)[0] + '_' + str(i + 1) + '_' + bbox_type + '.npy') | 
					
						
						|  | bbox_output = np.load(file_path) | 
					
						
						|  | output[bbox[i][1]:bbox[i][3] + 1, bbox[i][0]:bbox[i][2] + 1, 1:] += bbox_output[:, :, 1:] | 
					
						
						|  | count_predictions[bbox[i][1]:bbox[i][3] + 1, bbox[i][0]:bbox[i][2] + 1, 1:] += 1 | 
					
						
						|  | output[bbox[i][1]:bbox[i][3] + 1, bbox[i][0]:bbox[i][2] + 1, 0] \ | 
					
						
						|  | = np.minimum(output[bbox[i][1]:bbox[i][3] + 1, bbox[i][0]:bbox[i][2] + 1, 0], bbox_output[:, :, 0]) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | count_predictions[count_predictions == 0] = 1 | 
					
						
						|  | return output / count_predictions | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def get_instance(cat_gt, panoptic_seg_mask): | 
					
						
						|  | """ | 
					
						
						|  | """ | 
					
						
						|  | instance_gt = np.zeros_like(cat_gt, dtype=np.uint8) | 
					
						
						|  | num_humans = len(np.unique(panoptic_seg_mask)) - 1 | 
					
						
						|  | class_map = {} | 
					
						
						|  |  | 
					
						
						|  | total_part_num = 0 | 
					
						
						|  | for id in range(1, num_humans + 1): | 
					
						
						|  | human_part_label = np.where(panoptic_seg_mask == id, cat_gt, 0).astype(np.uint8) | 
					
						
						|  |  | 
					
						
						|  | part_classes = np.unique(human_part_label) | 
					
						
						|  |  | 
					
						
						|  | exceed = False | 
					
						
						|  | for part_id in part_classes: | 
					
						
						|  | if part_id == 0: | 
					
						
						|  | continue | 
					
						
						|  | total_part_num += 1 | 
					
						
						|  |  | 
					
						
						|  | if total_part_num > 255: | 
					
						
						|  | print("total_part_num exceed, return current instance map: {}".format(total_part_num)) | 
					
						
						|  | exceed = True | 
					
						
						|  | break | 
					
						
						|  | class_map[total_part_num] = part_id | 
					
						
						|  | instance_gt[np.where(human_part_label == part_id)] = total_part_num | 
					
						
						|  | if exceed: | 
					
						
						|  | break | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | ori_cur_labels = np.unique(instance_gt) | 
					
						
						|  | total_num_label = len(ori_cur_labels) | 
					
						
						|  | if instance_gt.max() + 1 != total_num_label: | 
					
						
						|  | for label in range(1, total_num_label): | 
					
						
						|  | instance_gt[instance_gt == ori_cur_labels[label]] = label | 
					
						
						|  |  | 
					
						
						|  | final_class_map = {} | 
					
						
						|  | for label in range(1, total_num_label): | 
					
						
						|  | if label >= 1: | 
					
						
						|  | final_class_map[label] = class_map[ori_cur_labels[label]] | 
					
						
						|  |  | 
					
						
						|  | return instance_gt, final_class_map | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def compute_confidence(im_name, feature_map, class_map, | 
					
						
						|  | instance_label, output_dir, | 
					
						
						|  | panoptic_seg_mask, seg_score_list): | 
					
						
						|  | """ | 
					
						
						|  | """ | 
					
						
						|  | conf_file = open(os.path.join(output_dir, os.path.splitext(im_name)[0] + '.txt'), 'w') | 
					
						
						|  |  | 
					
						
						|  | weighted_map = np.zeros_like(feature_map[:, :, 0]) | 
					
						
						|  | for index, score in enumerate(seg_score_list): | 
					
						
						|  | weighted_map += (panoptic_seg_mask == index + 1) * score | 
					
						
						|  |  | 
					
						
						|  | for label in class_map.keys(): | 
					
						
						|  | cls = class_map[label] | 
					
						
						|  | confidence = feature_map[:, :, cls].reshape(-1)[np.where(instance_label.reshape(-1) == label)] | 
					
						
						|  | confidence = (weighted_map * feature_map[:, :, cls].copy()).reshape(-1)[ | 
					
						
						|  | np.where(instance_label.reshape(-1) == label)] | 
					
						
						|  |  | 
					
						
						|  | confidence = confidence.sum() / len(confidence) | 
					
						
						|  | conf_file.write('{} {}\n'.format(cls, confidence)) | 
					
						
						|  |  | 
					
						
						|  | conf_file.close() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
def result_saving(fused_output, img_name, img_height, img_width, output_dir, mask_output_path, bbox_score, msrcnn_bbox):
    """Save the fused parsing results for one image.

    Writes under `output_dir`:
      - global_parsing/<stem>.png : argmax part-class map (palettized)
      - instance_parsing/<stem>.png and <stem>.txt : part-instance map plus
        per-part confidences (the txt is written by compute_confidence)
      - global_tag/<stem>.png and <stem>.txt : human-instance map plus one
        score/bbox line per kept human

    The mask array loaded from `mask_output_path` is mutated in memory only.
    """
    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    global_root = os.path.join(output_dir, 'global_parsing')
    instance_root = os.path.join(output_dir, 'instance_parsing')
    tag_dir = os.path.join(output_dir, 'global_tag')

    if not os.path.exists(global_root):
        os.makedirs(global_root)
    if not os.path.exists(instance_root):
        os.makedirs(instance_root)
    if not os.path.exists(tag_dir):
        os.makedirs(tag_dir)

    # 256-entry label palette shared by all three output PNGs.
    palette = get_palette(256)

    # Resize the fused logits back to the original image size, then take the
    # per-pixel argmax as the global part prediction.
    fused_output = cv2.resize(fused_output, dsize=(img_width, img_height), interpolation=cv2.INTER_LINEAR)
    seg_pred = np.asarray(np.argmax(fused_output, axis=2), dtype=np.uint8)
    # Human-instance mask; zero it wherever the fused parsing predicts
    # background so the two maps stay consistent.
    masks = np.load(mask_output_path)
    masks[np.where(seg_pred == 0)] = 0

    # NOTE: panoptic_seg_mask aliases `masks` — refine() below mutates it.
    panoptic_seg_mask = masks
    seg_score_list = bbox_score

    # Split the part map into per-(human, part) instances, then grow the
    # instance labels into matching unlabeled pixels. refine() mutates
    # instance_pred and panoptic_seg_mask in place, so it must run before
    # the confidence computation and the saves below.
    instance_pred, class_map = get_instance(seg_pred, panoptic_seg_mask)
    refine(instance_pred, panoptic_seg_mask, seg_pred, class_map)

    compute_confidence(img_name, fused_output, class_map, instance_pred, instance_root,
                       panoptic_seg_mask, seg_score_list)

    # NOTE(review): opened in append mode — rerunning over the same save_dir
    # duplicates lines; confirm this is intended.
    ins_seg_results = open(os.path.join(tag_dir, os.path.splitext(img_name)[0] + '.txt'), "a")
    keep_human_id_list = list(np.unique(panoptic_seg_mask))
    if 0 in keep_human_id_list:
        keep_human_id_list.remove(0)
    for i in keep_human_id_list:
        # Assumes human ids are contiguous 1..N so that i-1 indexes the
        # score/bbox lists — TODO confirm. Box coords are written with
        # indices [1][0][3][2] (presumably y1 x1 y2 x2; verify against
        # whatever consumes these files).
        ins_seg_results.write('{:.6f} {} {} {} {}\n'.format(seg_score_list[i - 1],
                                                            int(msrcnn_bbox[i - 1][1]), int(msrcnn_bbox[i - 1][0]),
                                                            int(msrcnn_bbox[i - 1][3]), int(msrcnn_bbox[i - 1][2])))
    ins_seg_results.close()

    # Save all three maps as palettized PNGs (PIL requires uint8 input;
    # panoptic_seg_mask's dtype comes from the .npy file — assumed uint8).
    output_im_global = PILImage.fromarray(seg_pred)
    output_im_instance = PILImage.fromarray(instance_pred)
    output_im_tag = PILImage.fromarray(panoptic_seg_mask)
    output_im_global.putpalette(palette)
    output_im_instance.putpalette(palette)
    output_im_tag.putpalette(palette)

    output_im_global.save(os.path.join(global_root, os.path.splitext(img_name)[0] + '.png'))
    output_im_instance.save(os.path.join(instance_root, os.path.splitext(img_name)[0] + '.png'))
    output_im_tag.save(os.path.join(tag_dir, os.path.splitext(img_name)[0] + '.png'))
					
						
						|  |  | 
					
						
						|  |  | 
					
						
def multi_process(a, args):
    """Fuse and save parsing results for one annotation record `a`.

    `a` is one entry of the test JSON's 'root' list; the keys read here are
    im_name, img_height, img_width, person_bbox and person_bbox_score.
    """
    img_name = a['im_name']
    img_height = a['img_height']
    img_width = a['img_width']
    msrcnn_bbox = a['person_bbox']
    bbox_score = a['person_bbox_score']

    # Whole-image global parsing logits.
    global_output = np.load(os.path.join(args.global_output_dir, os.path.splitext(img_name)[0] + '.npy'))

    # NOTE(review): msrcnn_output is computed (and its patch files must
    # exist on disk) but is never used below — the fusion is global + gt
    # only. Confirm whether this load is a leftover or whether the fusion
    # was meant to include it.
    msrcnn_output = patch2img_output(args.msrcnn_output_dir, img_name, img_height, img_width, msrcnn_bbox,
                                     bbox_type='msrcnn', num_class=20)

    # NOTE(review): patches are read from gt_output_dir but with
    # bbox_type='msrcnn' (filename suffix) — verify this is intentional.
    gt_output = patch2img_output(args.gt_output_dir, img_name, img_height, img_width, msrcnn_bbox, bbox_type='msrcnn',
                                 num_class=20)

    # Sum the whole-image and per-box logits; the argmax happens inside
    # result_saving.
    fused_output = global_output + gt_output

    mask_output_path = os.path.join(args.mask_output_dir, os.path.splitext(img_name)[0] + '_mask.npy')
    result_saving(fused_output, img_name, img_height, img_width, args.save_dir, mask_output_path, bbox_score, msrcnn_bbox)
    return
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def main(args): | 
					
						
						|  | json_file = open(args.test_json_path) | 
					
						
						|  | anno = json.load(json_file)['root'] | 
					
						
						|  |  | 
					
						
						|  | results = joblib.Parallel(n_jobs=24, verbose=10, pre_dispatch="all")( | 
					
						
						|  | [joblib.delayed(multi_process)(a, args) for i, a in enumerate(anno)] | 
					
						
						|  | ) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def get_arguments(): | 
					
						
						|  | parser = argparse.ArgumentParser(description="obtain final prediction by logits fusion") | 
					
						
						|  | parser.add_argument("--test_json_path", type=str, default='./data/CIHP/cascade_152_finetune/test.json') | 
					
						
						|  | parser.add_argument("--global_output_dir", type=str, | 
					
						
						|  | default='./data/CIHP/global/global_result-cihp-resnet101/global_output') | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | parser.add_argument("--gt_output_dir", type=str, | 
					
						
						|  | default='./data/CIHP/cascade_152__finetune/gt_result-cihp-resnet101/gt_output') | 
					
						
						|  | parser.add_argument("--mask_output_dir", type=str, default='./data/CIHP/cascade_152_finetune/mask') | 
					
						
						|  | parser.add_argument("--save_dir", type=str, default='./data/CIHP/fusion_results/cihp-msrcnn_finetune') | 
					
						
						|  | return parser.parse_args() | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if __name__ == '__main__': | 
					
						
						|  | args = get_arguments() | 
					
						
						|  | main(args) | 
					
						
						|  |  |