|
|
import os |
|
|
from tqdm.auto import tqdm |
|
|
import torch |
|
|
import numpy as np |
|
|
from utils.evaluation import cal_3d_position_error, match_2d_greedy, get_matching_dict, compute_prf1, vectorize_distance, calculate_iou |
|
|
from utils.transforms import pelvis_align, root_align, unNormalize |
|
|
from utils.visualization import tensor_to_BGR, pad_img |
|
|
from utils.visualization import vis_meshes_img, vis_boxes, vis_sat, vis_scale_img, get_colors_rgb |
|
|
from utils.box_ops import box_cxcywh_to_xyxy, box_xyxy_to_cxcywh |
|
|
from utils.constants import human36_eval_joint, J24_TO_H36M, H36M_TO_MPII |
|
|
import time |
|
|
import datetime |
|
|
import scipy.io as sio |
|
|
import cv2 |
|
|
import zipfile |
|
|
import pickle |
|
|
|
|
|
|
|
|
def select_and_align(smpl_joints, smpl_verts, body_verts_ind):
    """Select the 24 body joints and body-only vertices, then pelvis-align both.

    Note the ordering: the vertices are aligned using the pelvis computed from
    the *unaligned* joints, and only afterwards are the joints themselves
    aligned. Returns the (joints, verts) pair.
    """
    body_joints = smpl_joints[:24, :]
    body_verts = smpl_verts[body_verts_ind, :]
    assert len(body_verts.shape) == 2
    # Align verts first, while body_joints still carries the raw pelvis.
    aligned_verts = pelvis_align(body_joints, body_verts)
    aligned_joints = pelvis_align(body_joints)
    return aligned_joints, aligned_verts
|
|
|
|
|
|
|
|
|
|
|
def evaluate_agora(model, eval_dataloader, conf_thresh,
                   vis = True, vis_step = 40, results_save_path = None,
                   distributed = False, accelerator = None):
    """Evaluate a model on AGORA-style data: detection PRF1 plus MVE/MPJPE.

    For every image, predictions above ``conf_thresh`` are greedily matched to
    ground-truth people via their 2D joints; matched pairs contribute pelvis-
    aligned 3D joint/vertex errors, misses and false positives feed the
    precision/recall/F1 computation. When the dataset is an AGORA train split,
    a parallel set of "kid" metrics is accumulated for child instances.

    Args:
        model: the network; must expose ``human_model`` (SMPL layer),
            ``input_size``, and (when 'enc_outputs' is produced) ``sat_cfg``.
        eval_dataloader: yields ``(samples, targets)`` batches; targets carry
            'j2ds', 'j3ds', 'verts', 'img_path', 'cam_intrinsics' per image.
        conf_thresh: confidence threshold for keeping predicted queries.
        vis: if True, dump a composite debug image every ``vis_step`` images.
        vis_step: visualization stride (in global image index).
        results_save_path: directory for 'results.txt' and 'imgs/' (required).
        distributed: if True, gather per-process metrics via the accelerator.
        accelerator: HuggingFace-Accelerate-style object (required).

    Returns:
        dict of metrics, or a "Failed to evaluate" string if too few samples
        were matched to compute averages.
    """
    assert results_save_path is not None
    assert accelerator is not None
    num_processes = accelerator.num_processes

    # Kid annotations only exist on the AGORA train split.
    has_kid = ('train' in eval_dataloader.dataset.split and eval_dataloader.dataset.ds_name == 'agora')

    os.makedirs(results_save_path,exist_ok=True)
    if vis:
        imgs_save_dir = os.path.join(results_save_path, 'imgs')
        os.makedirs(imgs_save_dir, exist_ok = True)

    step = 0
    total_miss_count = 0
    total_count = 0
    total_fp = 0
    # Each list starts with a 0. sentinel; averaging later subtracts
    # num_processes from the length to cancel one sentinel per process.
    mve, mpjpe = [0.], [0.]

    if has_kid:
        kid_total_miss_count = 0
        kid_total_count = 0
        kid_mve, kid_mpjpe = [0.], [0.]

    cur_device = next(model.parameters()).device
    smpl_layer = model.human_model
    body_verts_ind = smpl_layer.body_vertex_idx

    progress_bar = tqdm(total=len(eval_dataloader), disable=not accelerator.is_local_main_process)
    progress_bar.set_description('evaluate')
    for itr, (samples, targets) in enumerate(eval_dataloader):
        samples=[sample.to(device = cur_device, non_blocking = True) for sample in samples]
        with torch.no_grad():
            outputs = model(samples, targets)
        bs = len(targets)
        for idx in range(bs):
            # Ground truth: first 24 joints only (SMPL body joints).
            gt_j2ds = targets[idx]['j2ds'].cpu().numpy()[:,:24,:]
            gt_j3ds = targets[idx]['j3ds'].cpu().numpy()[:,:24,:]
            gt_verts = targets[idx]['verts'].cpu().numpy()

            # Keep only queries whose confidence clears the threshold.
            select_queries_idx = torch.where(outputs['pred_confs'][idx] > conf_thresh)[0]
            pred_j2ds = outputs['pred_j2ds'][idx][select_queries_idx].detach().cpu().numpy()[:,:24,:]
            pred_j3ds = outputs['pred_j3ds'][idx][select_queries_idx].detach().cpu().numpy()[:,:24,:]
            pred_verts = outputs['pred_verts'][idx][select_queries_idx].detach().cpu().numpy()

            matched_verts_idx = []
            assert len(gt_j2ds.shape) == 3 and len(pred_j2ds.shape) == 3

            # Greedy 2D matching between predictions and GT (AGORA protocol);
            # matchDict maps str(gt index) -> pred index, 'miss', or 'invalid'.
            greedy_match = match_2d_greedy(pred_j2ds, gt_j2ds)
            matchDict, falsePositive_count = get_matching_dict(greedy_match)

            gt_verts_list, pred_verts_list, gt_joints_list, pred_joints_list = [], [], [], []
            gtIdxs = np.arange(len(gt_j3ds))
            miss_flag = []
            for gtIdx in gtIdxs:
                gt_verts_list.append(gt_verts[gtIdx])
                gt_joints_list.append(gt_j3ds[gtIdx])
                if matchDict[str(gtIdx)] == 'miss' or matchDict[str(
                        gtIdx)] == 'invalid':
                    # Unmatched GT person: placeholder entries, flagged as miss.
                    miss_flag.append(1)
                    pred_verts_list.append([])
                    pred_joints_list.append([])
                else:
                    miss_flag.append(0)
                    pred_joints_list.append(pred_j3ds[matchDict[str(gtIdx)]])
                    pred_verts_list.append(pred_verts[matchDict[str(gtIdx)]])
                    matched_verts_idx.append(matchDict[str(gtIdx)])

            if has_kid:
                # Per-GT-person boolean-like flags; assumed aligned with gtIdxs
                # ordering — TODO confirm against the dataset implementation.
                gt_kid_list = targets[idx]['kid']

            for i, (gt3d, pred) in enumerate(zip(gt_joints_list, pred_joints_list)):
                total_count += 1
                if has_kid and gt_kid_list[i]:
                    kid_total_count += 1

                # Misses count toward recall but contribute no 3D error.
                if miss_flag[i] == 1:
                    total_miss_count += 1
                    if has_kid and gt_kid_list[i]:
                        kid_total_miss_count += 1
                    continue

                gt3d = gt3d.reshape(-1, 3)
                pred3d = pred.reshape(-1, 3)
                gt3d_verts = gt_verts_list[i].reshape(-1, 3)
                pred3d_verts = pred_verts_list[i].reshape(-1, 3)

                # Pelvis-align joints and body-only vertices for both GT and pred.
                gt3d, gt3d_verts = select_and_align(gt3d, gt3d_verts, body_verts_ind)
                pred3d, pred3d_verts = select_and_align(pred3d, pred3d_verts, body_verts_ind)

                # PA (Procrustes-aligned) errors are computed but not reported.
                error_j, pa_error_j = cal_3d_position_error(pred3d, gt3d)
                mpjpe.append(error_j)
                if has_kid and gt_kid_list[i]:
                    kid_mpjpe.append(error_j)

                error_v,pa_error_v = cal_3d_position_error(pred3d_verts, gt3d_verts)
                mve.append(error_v)
                if has_kid and gt_kid_list[i]:
                    kid_mve.append(error_v)

            step += 1
            total_fp += falsePositive_count

            # Globally-unique-ish image index across processes, used only to
            # stride the visualization dumps.
            img_idx = step + accelerator.process_index*len(eval_dataloader)*bs

            if vis and (img_idx%vis_step == 0):
                img_name = targets[idx]['img_path'].split('/')[-1].split('.')[0]
                ori_img = tensor_to_BGR(unNormalize(samples[idx]).cpu())

                # GT meshes rendered in a near-white color.
                colors = [(1.0, 1.0, 0.9)] * len(gt_verts)
                gt_mesh_img = vis_meshes_img(img = ori_img.copy(),
                                             verts = gt_verts,
                                             smpl_faces = smpl_layer.faces,
                                             cam_intrinsics = targets[idx]['cam_intrinsics'].reshape(3,3).detach().cpu(),
                                             colors = colors)

                # Predicted meshes: pink by default, green when matched to a GT.
                colors = [(1.0, 0.6, 0.6)] * len(pred_verts)
                for i in matched_verts_idx:
                    colors[i] = (0.7, 1.0, 0.4)

                pred_mesh_img = vis_meshes_img(img = ori_img.copy(),
                                               verts = pred_verts,
                                               smpl_faces = smpl_layer.faces,
                                               cam_intrinsics = outputs['pred_intrinsics'][idx].reshape(3,3).detach().cpu(),
                                               colors = colors,
                                               )

                # Optional encoder scale-map visualization: scatter the sparse
                # per-position scale predictions back onto an (h, w, 2) grid.
                if 'enc_outputs' not in outputs:
                    pred_scale_img = np.zeros_like(pred_mesh_img)
                else:
                    enc_out = outputs['enc_outputs']
                    h, w = enc_out['hw'][idx]
                    flatten_map = enc_out['scale_map'].split(enc_out['lens'])[idx].detach().cpu()

                    ys = enc_out['pos_y'].split(enc_out['lens'])[idx]
                    xs = enc_out['pos_x'].split(enc_out['lens'])[idx]
                    scale_map = torch.zeros((h,w,2))
                    scale_map[ys,xs] = flatten_map

                    pred_scale_img = vis_scale_img(img = ori_img.copy(),
                                                   scale_map = scale_map,
                                                   conf_thresh = model.sat_cfg['conf_thresh'],
                                                   patch_size=28)

                # Boxes come out normalized cxcywh; convert to pixel xyxy.
                pred_boxes = outputs['pred_boxes'][idx][select_queries_idx].detach().cpu()
                pred_boxes = box_cxcywh_to_xyxy(pred_boxes) * model.input_size
                pred_box_img = vis_boxes(ori_img.copy(), pred_boxes, color = (255,0,255))

                sat_img = vis_sat(ori_img.copy(),
                                  input_size = model.input_size,
                                  patch_size = 14,
                                  sat_dict = outputs['sat'],
                                  bid = idx)

                ori_img = pad_img(ori_img, model.input_size)

                # 3x2 montage: original/sat, scale/boxes, GT mesh/pred mesh.
                full_img = np.vstack([np.hstack([ori_img, sat_img]),
                                      np.hstack([pred_scale_img, pred_box_img]),
                                      np.hstack([gt_mesh_img, pred_mesh_img])])

                cv2.imwrite(os.path.join(imgs_save_dir, f'{img_idx}_{img_name}.png'), full_img)

        progress_bar.update(1)

    if distributed:
        # Concatenate per-process metric lists / counters across processes.
        mve = accelerator.gather_for_metrics(mve)
        mpjpe = accelerator.gather_for_metrics(mpjpe)

        total_miss_count = sum(accelerator.gather_for_metrics([total_miss_count]))
        total_count = sum(accelerator.gather_for_metrics([total_count]))
        total_fp = sum(accelerator.gather_for_metrics([total_fp]))

        if has_kid:
            kid_mve = accelerator.gather_for_metrics(kid_mve)
            kid_mpjpe = accelerator.gather_for_metrics(kid_mpjpe)
            kid_total_miss_count = sum(accelerator.gather_for_metrics([kid_total_miss_count]))
            kid_total_count = sum(accelerator.gather_for_metrics([kid_total_count]))

    # Only sentinels present -> no matched sample contributed an error.
    if len(mpjpe) <= num_processes:
        return "Failed to evaluate. Keep training!"
    if has_kid and len(kid_mpjpe) <= num_processes:
        return "Failed to evaluate. Keep training!"

    precision, recall, f1 = compute_prf1(total_count,total_miss_count,total_fp)
    error_dict = {}
    error_dict['precision'] = precision
    error_dict['recall'] = recall
    error_dict['f1'] = f1

    # Subtract num_processes to discard the 0. sentinel from each process.
    # NM{J,V}E are the F1-normalized errors from the AGORA protocol.
    error_dict['MPJPE'] = round(sum(mpjpe)/(len(mpjpe)-num_processes), 1)
    error_dict['NMJE'] = round(error_dict['MPJPE'] / (f1), 1)
    error_dict['MVE'] = round(sum(mve)/(len(mve)-num_processes), 1)
    error_dict['NMVE'] = round(error_dict['MVE'] / (f1), 1)

    if has_kid:
        # NOTE(review): kid precision reuses the *overall* false-positive
        # count (total_fp) since FPs cannot be attributed to kids — confirm
        # this matches the intended protocol.
        kid_precision, kid_recall, kid_f1 = compute_prf1(kid_total_count,kid_total_miss_count,total_fp)
        error_dict['kid_precision'] = kid_precision
        error_dict['kid_recall'] = kid_recall
        error_dict['kid_f1'] = kid_f1

        error_dict['kid-MPJPE'] = round(sum(kid_mpjpe)/(len(kid_mpjpe)-num_processes), 1)
        error_dict['kid-NMJE'] = round(error_dict['kid-MPJPE'] / (kid_f1), 1)
        error_dict['kid-MVE'] = round(sum(kid_mve)/(len(kid_mve)-num_processes), 1)
        error_dict['kid-NMVE'] = round(error_dict['kid-MVE'] / (kid_f1), 1)

    # Only the main process writes the summary file.
    if accelerator.is_main_process:
        with open(os.path.join(results_save_path,'results.txt'),'w') as f:
            for k,v in error_dict.items():
                f.write(f'{k}: {v}\n')

    return error_dict
|
|
|
|
|
|
|
|
def test_agora(model, eval_dataloader, conf_thresh,
               vis = True, vis_step = 400, results_save_path = None,
               distributed = False, accelerator = None):
    """Run inference on the AGORA test split and export submission files.

    For each detected person a pickle with 2D joints and SMPL parameters is
    written under ``results_save_path/predictions`` (the AGORA evaluation
    server format), optionally dumping debug visualizations, and finally the
    whole predictions folder is zipped with a timestamped name.

    Args:
        model: the network; must expose ``human_model``, ``input_size`` and
            (when 'enc_outputs' is produced) ``sat_cfg``.
        eval_dataloader: yields ``(samples, targets)``; each target provides
            'img_name'.
        conf_thresh: confidence threshold for keeping predicted queries.
        vis: if True, dump a composite debug image every ``vis_step`` images.
        vis_step: visualization stride (in global image index).
        results_save_path: output directory (required).
        distributed: unused here; kept for signature parity with
            ``evaluate_agora``.
        accelerator: HuggingFace-Accelerate-style object (required).

    Returns:
        A message string with the predictions directory path.
    """
    assert results_save_path is not None
    assert accelerator is not None

    os.makedirs(os.path.join(results_save_path,'predictions'),exist_ok=True)
    if vis:
        imgs_save_dir = os.path.join(results_save_path, 'imgs')
        os.makedirs(imgs_save_dir, exist_ok = True)
    step = 0
    cur_device = next(model.parameters()).device
    smpl_layer = model.human_model

    progress_bar = tqdm(total=len(eval_dataloader), disable=not accelerator.is_local_main_process)
    progress_bar.set_description('testing')
    for itr, (samples, targets) in enumerate(eval_dataloader):
        samples=[sample.to(device = cur_device, non_blocking = True) for sample in samples]
        with torch.no_grad():
            outputs = model(samples, targets)
        bs = len(targets)
        for idx in range(bs):
            img_name = targets[idx]['img_name'].split('.')[0]

            # Keep only confident queries. 2D joints are rescaled from model
            # input resolution back to 3840 px — presumably the native AGORA
            # 4K image width; confirm against the dataset.
            select_queries_idx = torch.where(outputs['pred_confs'][idx] > conf_thresh)[0]
            pred_j2ds = np.array(outputs['pred_j2ds'][idx][select_queries_idx].detach().to('cpu'))[:,:24,:]*(3840/model.input_size)
            # pred_j3ds is extracted but not used below (only 2D joints and
            # SMPL params are exported).
            pred_j3ds = np.array(outputs['pred_j3ds'][idx][select_queries_idx].detach().to('cpu'))[:,:24,:]
            pred_verts = np.array(outputs['pred_verts'][idx][select_queries_idx].detach().to('cpu'))
            pred_poses = np.array(outputs['pred_poses'][idx][select_queries_idx].detach().to('cpu'))
            pred_betas = np.array(outputs['pred_betas'][idx][select_queries_idx].detach().to('cpu'))

            step+=1
            # Globally-unique-ish image index across processes; drives the
            # visualization stride only.
            img_idx = step + accelerator.process_index*len(eval_dataloader)*bs
            if vis and (img_idx%vis_step == 0):
                ori_img = tensor_to_BGR(unNormalize(samples[idx]).cpu())
                ori_img = pad_img(ori_img, model.input_size)

                sat_img = vis_sat(ori_img.copy(),
                                  input_size = model.input_size,
                                  patch_size = 14,
                                  sat_dict = outputs['sat'],
                                  bid = idx)

                # One distinct color per predicted person.
                colors = get_colors_rgb(len(pred_verts))
                mesh_img = vis_meshes_img(img = ori_img.copy(),
                                          verts = pred_verts,
                                          smpl_faces = smpl_layer.faces,
                                          colors = colors,
                                          cam_intrinsics = outputs['pred_intrinsics'][idx].detach().cpu())

                # Optional encoder scale-map visualization: scatter the sparse
                # per-position scale predictions back onto an (h, w, 2) grid.
                if 'enc_outputs' not in outputs:
                    pred_scale_img = np.zeros_like(ori_img)
                else:
                    enc_out = outputs['enc_outputs']
                    h, w = enc_out['hw'][idx]
                    flatten_map = enc_out['scale_map'].split(enc_out['lens'])[idx].detach().cpu()

                    ys = enc_out['pos_y'].split(enc_out['lens'])[idx]
                    xs = enc_out['pos_x'].split(enc_out['lens'])[idx]
                    scale_map = torch.zeros((h,w,2))
                    scale_map[ys,xs] = flatten_map
                    pred_scale_img = vis_scale_img(img = ori_img.copy(),
                                                   scale_map = scale_map,
                                                   conf_thresh = model.sat_cfg['conf_thresh'],
                                                   patch_size=28)

                # 2x2 montage: original/meshes on top, scale/sat below.
                full_img = np.vstack([np.hstack([ori_img, mesh_img]),
                                      np.hstack([pred_scale_img, sat_img])])
                cv2.imwrite(os.path.join(imgs_save_dir, f'{img_idx}_{img_name}.jpg'), full_img)

            # One pickle per detected person, in the AGORA submission layout:
            # 'joints' (24x2 pixel coords) + SMPL 'params' (transl fixed to 0).
            for pnum in range(len(pred_j2ds)):
                smpl_dict = {}

                smpl_dict['joints'] = pred_j2ds[pnum].reshape(24,2)
                smpl_dict['params'] = {'transl': np.zeros((1,3)),
                                       'betas': pred_betas[pnum].reshape(1,10),
                                       'global_orient': pred_poses[pnum][:3].reshape(1,1,3),
                                       'body_pose': pred_poses[pnum][3:].reshape(1,23,3)}

                with open(os.path.join(results_save_path,'predictions',f'{img_name}_personId_{pnum}.pkl'), 'wb') as f:
                    pickle.dump(smpl_dict, f)

        progress_bar.update(1)

    accelerator.print('Packing...')

    # Zip the whole predictions folder with a timestamped archive name.
    # NOTE(review): every process executes this packing step — under
    # multi-process launch each rank writes its own (differently timestamped)
    # zip; confirm that is intended.
    folder_path = os.path.join(results_save_path,'predictions')
    now = datetime.datetime.now()
    timestamp = now.strftime("%Y%m%d_%H%M%S")
    output_path = os.path.join(results_save_path,f'pred_{timestamp}.zip')
    with zipfile.ZipFile(output_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        for root, dirs, files in os.walk(folder_path):
            for file in files:
                file_path = os.path.join(root, file)
                # Store paths relative to the parent so the archive keeps the
                # 'predictions/' prefix.
                arcname = os.path.relpath(file_path, os.path.dirname(folder_path))
                zipf.write(file_path, arcname)

    return 'Results saved at: ' + os.path.join(results_save_path,'predictions')
|
|
|
|
|
|