Spaces:
Running
Running
import numpy as np | |
import matplotlib.pyplot as plt | |
import matplotlib.image as mpimg | |
import cv2 | |
import skimage | |
import torch | |
from PIL import Image | |
# Joint names, one per keypoint index. show_heatmaps() uses entry i as the
# title of heatmap channel i, so the order must match the heatmap channels.
joints = [
    # left leg, ankle -> hip
    'left ankle', 'left knee', 'left hip',
    # right leg, hip -> ankle
    'right hip', 'right knee', 'right ankle',
    # torso, bottom -> top
    'belly', 'chest', 'neck', 'head',
    # left arm, wrist -> shoulder
    'left wrist', 'left elbow', 'left shoulder',
    # right arm, shoulder -> wrist
    'right shoulder', 'right elbow', 'right wrist',
]
def generate_heatmap(heatmap, pt, sigma=(33, 33), sigma_valu=7):
    '''
    Render one joint as a Gaussian heatmap whose peak value is 1.

    A copy of ``heatmap`` gets a 1 written at ``pt``, is Gaussian-blurred and
    then divided by its maximum. The array passed in is never modified, so the
    same zeros buffer can safely be reused across calls.

    :param heatmap: should be a np zeros array with shape (H,W) (only one channel), not (H,W,1)
    :param pt: point coords (x, y), np array; values are truncated to ints
    :param sigma: unused, kept only for backward compatibility (obsolete)
    :param sigma_valu: standard deviation of the gaussian blur
    :return: a np array of one joint heatmap with shape (H,W)
    This function is obsolete, use 'generate_heatmaps()' instead.
    '''
    # Function-scope import keeps the file's import block untouched.
    from scipy import ndimage

    src = np.asarray(heatmap)
    # astype() always copies; non-float inputs are promoted to a float type so
    # the blur is not truncated to integers.
    canvas = src.astype(np.result_type(src.dtype, np.float32))
    canvas[int(pt[1]), int(pt[0])] = 1

    # Same filter skimage.filters.gaussian runs with its default settings
    # (edge mode 'nearest', kernel truncated at 4 sigma).
    blurred = ndimage.gaussian_filter(
        canvas, sigma=sigma_valu, mode='nearest', truncate=4.0)

    # Rescale so the peak is exactly 1.
    return blurred / np.amax(blurred)
def generate_heatmaps(img, pts, sigma=(33, 33), sigma_valu=7):
    '''
    Build one peak-normalised Gaussian heatmap per joint.

    A joint whose x and y are both 0 is treated as unavailable and its channel
    stays all zeros. Coordinates outside the image are clamped to the nearest
    border pixel. ``pts`` itself is never modified.

    :param img: np array img, (H,W,C); only its height and width are used
    :param pts: joint points coords, np array (num_pts, >=2) as (x, y, ...), same resolu as img
    :param sigma: unused, kept only for backward compatibility (obsolete)
    :param sigma_valu: standard deviation of the gaussian blur
    :return: np array heatmaps, (H,W,num_pts)
    '''
    # Function-scope import keeps the file's import block untouched.
    from scipy import ndimage

    H, W = img.shape[0], img.shape[1]
    pts = np.asarray(pts)
    num_pts = pts.shape[0]
    heatmaps = np.zeros((H, W, num_pts))

    for i, pt in enumerate(pts):
        x, y = pt[0], pt[1]
        # Filter unavailable joints
        if x == 0 and y == 0:
            continue

        # Clamp into the image using locals: writing back into ``pt`` would
        # silently edit the caller's array, and a negative index would wrap
        # around to the opposite border.
        col = min(max(int(x), 0), W - 1)
        row = min(max(int(y), 0), H - 1)

        impulse = np.zeros((H, W))
        impulse[row, col] = 1
        # Same filter skimage.filters.gaussian runs with its default settings
        # (edge mode 'nearest', kernel truncated at 4 sigma).
        blurred = ndimage.gaussian_filter(
            impulse, sigma=sigma_valu, mode='nearest', truncate=4.0)
        # Rescale so the peak of every channel is exactly 1.
        heatmaps[:, :, i] = blurred / np.amax(blurred)

    return heatmaps
def load_image(path_image):
    '''
    Read an image file through matplotlib.

    :param path_image: path of the image file to read
    :return: whatever ``mpimg.imread`` yields for the file, a np array (H,W,C)
    '''
    return mpimg.imread(path_image)
def crop(img, ele_anno, use_randscale=True, use_randflipLR=False, use_randcolor=False):
    '''
    Cut a square-ish window around the annotated person and move the joints
    and centre into the coordinate frame of that window.

    The window is the bounding box of the available joints, grown by a margin
    of ``15 * scale`` pixels (optionally multiplied by a random factor in
    [1, 3)), widened to a square, clipped to the image and finally zero-padded
    back to a square. The pixel values are divided by 255 and returned as
    float64. ``ele_anno`` is only read, never modified.

    A joint row that is entirely zero counts as "not available" and is
    returned unchanged (all zeros).

    :param img: np array of the origin image, (H,W,C)
    :param ele_anno: one element of json annotation; the keys 'scale_provided',
        'objpos' (x, y) and 'joint_self' (rows of x, y, flag) are read
    :param use_randscale: draw the margin multiplier from uniform(1, 3) instead of 1
    :param use_randflipLR: mirror image, joints and centre left-right with probability 0.5
    :param use_randcolor: with probability 0.5 scale each colour channel by a random factor
    :return: img_crop, ary_pts_crop, c_crop after cropping
    :raises ValueError: raised by numpy when every joint row is all zeros
        (there is nothing to take the bounding box of)
    '''
    H, W = img.shape[0], img.shape[1]
    s = ele_anno['scale_provided']
    # Copy the centre. Shifting the annotation's own list in place would move
    # it further down every time the same annotation is cropped again.
    c = np.array(ele_anno['objpos'], dtype=np.float64)

    # Adjust center and scale
    if c[0] != -1:
        c[1] = c[1] + 15 * s
        s = s * 1.25

    ary_pts = np.array(ele_anno['joint_self'])  # (16, 3)
    # Row-wise mask of joints that carry any information at all.
    valid = np.any(ary_pts != 0, axis=1)
    ary_pts_temp = ary_pts[valid]

    if use_randscale:
        scale_rand = np.random.uniform(low=1.0, high=3.0)
    else:
        scale_rand = 1

    # Bounding box of the available joints plus margin, clipped to the image.
    margin = s * 15 * scale_rand
    pts_min = np.amin(ary_pts_temp, axis=0)
    pts_max = np.amax(ary_pts_temp, axis=0)
    W_min = max(pts_min[0] - margin, 0)
    H_min = max(pts_min[1] - margin, 0)
    W_max = min(pts_max[0] + margin, W)
    H_max = min(pts_max[1] + margin, H)

    # Grow the shorter side symmetrically so the window becomes square.
    W_len = W_max - W_min
    H_len = H_max - H_min
    window_len = max(H_len, W_len)
    pad_updown = (window_len - H_len) / 2
    pad_leftright = (window_len - W_len) / 2

    # Calculate 4 corner position
    W_low = max((W_min - pad_leftright), 0)
    W_high = min((W_max + pad_leftright), W)
    H_low = max((H_min - pad_updown), 0)
    H_high = min((H_max + pad_updown), H)

    # Update joint points and center (unavailable joints stay all-zero)
    ary_pts_crop = ary_pts.astype(np.float64)
    ary_pts_crop[valid] -= np.array([W_low, H_low, 0])
    c_crop = c - np.array([W_low, H_low])

    img_crop = img[int(H_low):int(H_high), int(W_low):int(W_high), :]

    # Pad when H, W different (the window may have been clipped by the image border)
    H_new, W_new = img_crop.shape[0], img_crop.shape[1]
    window_len_new = max(H_new, W_new)
    pad_updown_new = int((window_len_new - H_new) / 2)
    pad_leftright_new = int((window_len_new - W_new) / 2)

    # ReUpdate joint points and center (because of the padding)
    ary_pts_crop[valid] += np.array([pad_leftright_new, pad_updown_new, 0])
    c_crop = c_crop + np.array([pad_leftright_new, pad_updown_new])

    # Constant zero border on both sides of each spatial axis.
    img_crop = np.pad(
        img_crop,
        ((pad_updown_new, pad_updown_new),
         (pad_leftright_new, pad_leftright_new),
         (0, 0)),
        mode='constant', constant_values=0)

    # change dtype and num scale
    img_crop = img_crop / 255.
    img_crop = img_crop.astype(np.float64)

    if use_randflipLR:
        flip = np.random.random() > 0.5
        if flip:
            # (H,W,C)
            img_crop = np.flip(img_crop, 1)
            # Mirror x of the available joints only; y and the third column
            # (the per-joint flag) are kept as they are.
            ary_pts_crop[valid, 0] = window_len_new - ary_pts_crop[valid, 0]
            c_crop = np.array([window_len_new - c_crop[0], c_crop[1]])
            # Rearrange pts: swap the left/right limbs, keep the torso rows 6-9
            ary_pts_crop = np.concatenate(
                (ary_pts_crop[5::-1], ary_pts_crop[6:10], ary_pts_crop[15:9:-1]))

    if use_randcolor:
        randcolor = np.random.random() > 0.5
        if randcolor:
            # NOTE(review): the clip is applied to the factor, not to the
            # result, so each factor ends up in [0.8, 1.0] and a channel is
            # only ever darkened. Kept as found - confirm whether clipping
            # the scaled pixels was intended.
            for channel in range(3):
                img_crop[..., channel] *= np.clip(
                    np.random.uniform(low=0.8, high=1.2), 0., 1.)

    return img_crop, ary_pts_crop, c_crop
def change_resolu(img, pts, c, resolu_out=(256, 256)):
    '''
    Resize an image and rescale its joints and centre by the same ratios.

    :param img: np array of the origin image
    :param pts: joint points np array (x, y, flag) corresponding to the image, same resolu as img
    :param c: center (x, y)
    :param resolu_out: a list or tuple (H_out, W_out)
    :return: img_out, pts_out, c_out under resolu_out
    '''
    # Input/output size ratio per axis.
    h_ratio = img.shape[0] / resolu_out[0]
    w_ratio = img.shape[1] / resolu_out[1]

    # x is divided by the width ratio, y by the height ratio; the third
    # column of pts is divided by 1, i.e. left as it is.
    xy_ratio = np.array([w_ratio, h_ratio])
    pts_out = pts / np.array([w_ratio, h_ratio, 1])
    c_out = c / xy_ratio

    img_out = skimage.transform.resize(img, tuple(resolu_out))
    return img_out, pts_out, c_out
def heatmaps_to_coords(heatmaps, resolu_out=(64, 64), prob_threshold=0.2):
    '''
    Pick the peak of every joint heatmap.

    :param heatmaps: array with shape (H,W,num_joints), e.g. (64,64,16)
    :param resolu_out: output resolution (H,W), list or tuple; the heatmaps are
        resized to it before the peaks are located
    :param prob_threshold: a joint whose peak is below this value gets the
        coordinates (0, 0); its peak value is still reported
    :return coord_joints: np array, shape (num_joints,3), one row of
        (x, y, peak value) per joint
    '''
    heatmaps = np.asarray(heatmaps)
    target = tuple(resolu_out)
    num_joints = heatmaps.shape[2]

    # Resize - skipped when a float array already has the requested size,
    # where resampling would only reproduce the input.
    already_sized = (heatmaps.shape[:2] == target
                     and np.issubdtype(heatmaps.dtype, np.floating))
    if not already_sized:
        heatmaps = skimage.transform.resize(heatmaps, target)

    coord_joints = np.zeros((num_joints, 3))
    for i in range(num_joints):
        heatmap = heatmaps[..., i]
        peak = np.max(heatmap)
        # Only keep points larger than a threshold
        if peak >= prob_threshold:
            # argmax returns the first maximum in row-major order
            row, col = np.unravel_index(np.argmax(heatmap), heatmap.shape)
        else:
            row, col = 0, 0
        coord_joints[i] = [col, row, peak]
    return coord_joints
def show_heatmaps(img, heatmaps, c=np.zeros((2)), num_fig=1):
    '''
    Plot the image followed by one panel per joint heatmap and show the figure.

    The heatmaps are resized to the image size first when either their height
    or their width differs from it. Panels are laid out five per row (at
    least four rows); a channel without an entry in ``joints`` gets a generic
    title.

    :param img: np array (H,W,3)
    :param heatmaps: np array (H,W,num_pts)
    :param c: center, np array (2,); currently unused
    :param num_fig: matplotlib figure number to draw into
    '''
    H, W = img.shape[0], img.shape[1]
    if heatmaps.shape[0] != H or heatmaps.shape[1] != W:
        heatmaps = skimage.transform.resize(heatmaps, (H, W))

    num_maps = heatmaps.shape[2]
    cols = 5
    # One extra cell for the image itself; never fewer than the 4 rows
    # used for the usual 16 joints.
    rows = max(4, (num_maps + cols) // cols)

    plt.figure(num_fig)

    plt.subplot(rows, cols, 1)
    plt.title('Origin')
    plt.imshow(img)
    plt.axis('off')

    for i in range(num_maps):
        plt.subplot(rows, cols, i + 2)
        plt.title(joints[i] if i < len(joints) else f'joint {i}')
        plt.imshow(heatmaps[:, :, i])
        plt.axis('off')

    # Hide the frame of the last grid cell, which is normally left empty.
    plt.subplot(rows, cols, rows * cols)
    plt.axis('off')
    plt.show()
def heatmap2rgb(heatmap):
    """
    Colour a single heatmap with matplotlib's 'jet' colormap.

    The values are min-max normalised before the colormap is applied, and the
    alpha channel the colormap returns is dropped.

    : heatmap: torch tensor (h,w)
    : return: torch tensor (3,h,w) with the RGB values produced by the colormap
    """
    heatmap = heatmap.detach().cpu().numpy()

    cm = plt.get_cmap('jet')
    lowest = np.min(heatmap)
    highest = np.max(heatmap)
    # The epsilon belongs to the denominator: it keeps a constant heatmap
    # (highest == lowest) from dividing by zero.
    normed_data = (heatmap - lowest) / (highest - lowest + 1e-8)
    mapped_data = cm(normed_data)

    # (h,w,4) RGBA -> (h,w,3) RGB
    img = np.array(mapped_data)
    img = img[:, :, :3]
    # (h,w,c) -> (c,h,w)
    img = torch.tensor(img).permute(2, 0, 1)
    return img
def heatmaps2rgb(heatmaps):
    """
    Colour every heatmap of a batch via heatmap2rgb() and stack the results.

    : heatmaps: (b,h,w)
    : return: torch tensor with the per-heatmap results stacked along a new first axis
    """
    return torch.stack([heatmap2rgb(single) for single in heatmaps])
# def draw_joints(img, pts): | |
# scores = pts[:,2] | |
# pts = np.array(pts).astype(int) | |
# for i in range(pts.shape[0]): | |
# if pts[i, 0] != 0 and pts[i, 1] != 0: | |
# img = cv2.circle(img, (pts[i, 0], pts[i, 1]), radius=3, | |
# color=(255, 0, 0), thickness=-1) | |
# print('img',img.max(),img.min()) | |
# # img = cv2.putText(img, f'{joints[i]}: {scores[i]:.2f}', ( | |
# # pts[i, 0]+5, pts[i, 1]-5), cv2.FONT_HERSHEY_SIMPLEX, .25, (255, 0, 0)) | |
# # Left arm | |
# for i in range(10, 13-1): | |
# if pts[i, 0] != 0 and pts[i, 1] != 0 and pts[i+1, 0] != 0 and pts[i+1, 1] != 0: | |
# img = cv2.line(img, (pts[i, 0], pts[i, 1]), (pts[i+1, 0], | |
# pts[i+1, 1]), color=(255, 0, 0), thickness=1) | |
# # Right arm | |
# for i in range(13, 16-1): | |
# if pts[i, 0] != 0 and pts[i, 1] != 0 and pts[i+1, 0] != 0 and pts[i+1, 1] != 0: | |
# img = cv2.line(img, (pts[i, 0], pts[i, 1]), (pts[i+1, 0], | |
# pts[i+1, 1]), color=(255, 0, 0), thickness=1) | |
# # Left leg | |
# for i in range(0, 3-1): | |
# if pts[i, 0] != 0 and pts[i, 1] != 0 and pts[i+1, 0] != 0 and pts[i+1, 1] != 0: | |
# img = cv2.line(img, (pts[i, 0], pts[i, 1]), (pts[i+1, 0], | |
# pts[i+1, 1]), color=(255, 0, 0), thickness=1) | |
# # Right leg | |
# for i in range(3, 6-1): | |
# if pts[i, 0] != 0 and pts[i, 1] != 0 and pts[i+1, 0] != 0 and pts[i+1, 1] != 0: | |
# img = cv2.line(img, (pts[i, 0], pts[i, 1]), (pts[i+1, 0], | |
# pts[i+1, 1]), color=(255, 0, 0), thickness=1) | |
# # Body | |
# for i in range(6, 10-1): | |
# if pts[i, 0] != 0 and pts[i, 1] != 0 and pts[i+1, 0] != 0 and pts[i+1, 1] != 0: | |
# img = cv2.line(img, (pts[i, 0], pts[i, 1]), (pts[i+1, 0], | |
# pts[i+1, 1]), color=(255, 0, 0), thickness=1) | |
# if pts[2, 0] != 0 and pts[2, 1] != 0 and pts[3, 0] != 0 and pts[3, 1] != 0: | |
# img = cv2.line(img, (pts[2, 0], pts[2, 1]), (pts[2+1, 0], | |
# pts[2+1, 1]), color=(255, 0, 0), thickness=1) | |
# if pts[12, 0] != 0 and pts[12, 1] != 0 and pts[13, 0] != 0 and pts[13, 1] != 0: | |
# img = cv2.line(img, (pts[12, 0], pts[12, 1]), (pts[12+1, 0], | |
# pts[12+1, 1]), color=(255, 0, 0), thickness=1) | |
# return img | |
def draw_joints(img, pts):
    '''
    Draw the skeleton that connects the joints onto a copy of the image.

    The image is cast to uint8 as it is (no rescaling is applied), the limbs
    are drawn with draw_line(), and the result is divided by 255.
    NOTE(review): presumably img holds values in 0-255; a float image in
    [0, 1] would be truncated to 0/1 by the cast - confirm against callers.

    :param img: np array image
    :param pts: joint points, indexable as pts[i] -> (x, y, ...), ordered like ``joints``
    :return: float np array, the annotated image divided by 255
    '''
    canvas = img.astype(np.uint8)

    # Consecutive-index chains as [first, last+1): left arm, right arm,
    # left leg, right leg, torso (belly -> head).
    chains = ((10, 13), (13, 16), (0, 3), (3, 6), (6, 10))
    for first, end in chains:
        for idx in range(first, end - 1):
            draw_line(canvas, pts[idx], pts[idx + 1])

    # Cross links: hip to hip, shoulder to shoulder.
    for a, b in ((2, 3), (12, 13)):
        draw_line(canvas, pts[a], pts[b])

    return canvas / 255.0
def draw_line(img, pt1, pt2):
    '''
    Draw one limb segment on ``img`` in place, unless an endpoint is missing.

    An endpoint counts as missing as soon as its x or its y equals 0; in that
    case nothing is drawn.

    :param img: image array handed to cv2.line (modified in place)
    :param pt1: first endpoint, (x, y, ...)
    :param pt2: second endpoint, (x, y, ...)
    '''
    if pt1[0] == 0 or pt1[1] == 0 or pt2[0] == 0 or pt2[1] == 0:
        return

    start = (int(pt1[0]), int(pt1[1]))
    end = (int(pt2[0]), int(pt2[1]))
    cv2.line(img, start, end, color=(57, 255, 20), thickness=2)