Spaces:
Running
Running
feat: upload project
Browse files- app.py +43 -4
- examples/000000000016.jpg +0 -0
- examples/000000000552.jpg +0 -0
- models/hr_net.py +163 -0
- models/modules/__init__.py +0 -0
- models/modules/blocks/__init__.py +0 -0
- models/modules/blocks/basic_block.py +37 -0
- models/modules/blocks/bottleneck.py +57 -0
- models/modules/stage_module.py +104 -0
- models/modules/stem.py +29 -0
- requirements.txt +17 -0
- tool_utils.py +372 -0
app.py
CHANGED
|
@@ -1,7 +1,46 @@
|
|
| 1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
|
|
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
import time
|
| 3 |
+
import numpy
|
| 4 |
+
import os
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import matplotlib.pyplot as plt
|
| 7 |
+
import torch
|
| 8 |
+
import skimage
|
| 9 |
+
from models.hr_net import hr_w32
|
| 10 |
+
from tool_utils import heatmaps_to_coords,draw_joints
|
| 11 |
|
| 12 |
+
# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# Gradio example rows: one single-element list per file found in 'examples/'.
example_list = [[f"./examples/{file_name}"] for file_name in os.listdir("examples")]
|
| 15 |
|
| 16 |
+
def _load_model():
    """Build HRNet-W32, load the trained checkpoint and cache it for later calls.

    The checkpoint is read only on the first call; later calls return the same
    model object. ``map_location`` lets a checkpoint saved on a GPU be loaded
    on a CPU-only host.
    """
    cached = getattr(_load_model, "_model", None)
    if cached is None:
        cached = hr_w32().to(device)
        checkpoint = torch.load('./weights/HRNet_epoch20_loss0.000474.pth',
                                map_location=device)
        cached.load_state_dict(checkpoint['model'])
        # Inference mode: BatchNorm layers use their running statistics.
        cached.eval()
        _load_model._model = cached
    return cached


def predict(numpy_img):
    """Estimate joint positions on one image and draw them.

    :param numpy_img: image as a numpy array, (H,W) or (H,W,C)
    :return: (img_rgb, inference_time_text) - the 256x256 image returned by
        ``draw_joints`` and a text line with the elapsed inference time
    :raises gr.Error: if no image was supplied
    """
    if numpy_img is None:
        raise gr.Error("Please provide an image.")

    # Resize to the 256x256 resolution fed to the network.
    img_np = skimage.transform.resize(numpy_img, [256, 256])
    # The first convolution takes exactly three channels: replicate a
    # grayscale plane, drop anything beyond the first three (e.g. alpha).
    if img_np.ndim == 2:
        img_np = numpy.stack([img_np] * 3, axis=-1)
    elif img_np.shape[2] > 3:
        img_np = numpy.ascontiguousarray(img_np[..., :3])

    # (H,W,C) array -> (1,C,H,W) float tensor on the inference device.
    img = torch.from_numpy(img_np).permute(2, 0, 1).unsqueeze(0).float().to(device)

    model = _load_model()

    # Timed section: forward pass plus heatmap -> coordinate decoding.
    start_time = time.time()
    with torch.no_grad():  # no autograd graph is needed for inference
        heatmaps_pred = model(img).double()
    # (1,J,h,w) -> (h,w,J) numpy array
    heatmaps_pred_np = heatmaps_pred.squeeze(0).permute(1, 2, 0).cpu().numpy()
    coord_joints = heatmaps_to_coords(
        heatmaps_pred_np, resolu_out=[256, 256], prob_threshold=0.1)
    inference_time = time.time() - start_time
    inference_time_text = "model inference time:{:.4f}s".format(inference_time)

    # Draw the decoded joints on the resized image.
    img_rgb = draw_joints(img_np, coord_joints)
    return img_rgb, inference_time_text
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# Gradio UI: one image input; outputs are a 256x256 image and a text field,
# matching the (img_rgb, inference_time_text) pair returned by predict.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(),
    outputs=[gr.Image(type='numpy', width=256, height=256), "text"],
    examples=example_list,
)

if __name__ == "__main__":
    demo.launch(show_api=False)
|
examples/000000000016.jpg
ADDED
|
examples/000000000552.jpg
ADDED
|
models/hr_net.py
ADDED
|
@@ -0,0 +1,163 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
|
| 4 |
+
from models.modules.blocks.bottleneck import Bottleneck
|
| 5 |
+
from models.modules.stage_module import StageModule
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def weights_init(m):
    """Initialise one module in place; suitable for ``nn.Module.apply``.

    Conv2d weights are drawn from a normal distribution with std 0.01,
    BatchNorm2d weights are set to 1, and a bias present on either layer type
    is set to 0. Any other module type is left untouched.
    """
    is_conv = isinstance(m, nn.Conv2d)
    if not is_conv and not isinstance(m, nn.BatchNorm2d):
        return

    if is_conv:
        nn.init.normal_(m.weight, std=.01)
    else:
        nn.init.constant_(m.weight, 1)

    if m.bias is not None:
        nn.init.constant_(m.bias, 0)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class HRNet(nn.Module):
    """Multi-branch high-resolution network that outputs one heatmap per joint.

    :param c: channel width of the highest-resolution branch; branch i uses c * 2**i
    :param nof_joints: number of output heatmap channels
    :param bn_momentum: BatchNorm momentum used by the stem and the stage modules
    """

    def __init__(self, c=48, nof_joints=16, bn_momentum=.1):
        super().__init__()

        def conv_bn_relu(in_ch, out_ch, stride):
            # 3x3 conv + BatchNorm + ReLU, the unit every transition is built from.
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=3,
                          stride=stride, padding=1, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            )

        # Stem: two stride-2 3x3 convolutions, 3 -> 64 channels.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64, momentum=bn_momentum)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3,
                               stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64, momentum=bn_momentum)
        self.relu = nn.ReLU(inplace=True)

        # Stage 1: (b,64,y,x) -> (b,256,y,x). Only the first Bottleneck needs a
        # 1x1 projection on its shortcut because it changes the channel count.
        downsample = nn.Sequential(
            nn.Conv2d(64, 256, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(256),
        )
        self.layer1 = nn.Sequential(
            Bottleneck(64, 64, downsample=downsample),
            *[Bottleneck(256, 64) for _ in range(3)],
        )

        # (b,256,y,x) ---+---> (b,c,y,x)
        #                +---> (b,c*2,y/2,x/2)
        self.transition1 = nn.ModuleList([
            conv_bn_relu(256, c, 1),
            nn.Sequential(conv_bn_relu(256, c * 2, 2)),
        ])

        # Stage 2: the branches are fused inside StageModule.
        # (b,c,y,x) ------+---> (b,c,y,x)
        # (b,c*2,y/2,x/2) +---> (b,c*2,y/2,x/2)
        self.stage2 = nn.Sequential(
            StageModule(stage=2, output_branches=2, c=c, bn_momentum=bn_momentum)
        )

        # Existing branches pass through unchanged (an empty Sequential is used
        # in place of None because it is callable); a third branch is added.
        # (b,c,y,x) ----------> (b,c,y,x)
        # (b,c*2,y/2,x/2) +---> (b,c*2,y/2,x/2)
        #                 +---> (b,c*4,y/4,x/4)
        self.transition2 = nn.ModuleList([
            nn.Sequential(),
            nn.Sequential(),
            nn.Sequential(conv_bn_relu(c * 2, c * 4, 2)),
        ])

        # Stage 3: four fusion modules over three branches.
        self.stage3 = nn.Sequential(*[
            StageModule(stage=3, output_branches=3, c=c, bn_momentum=bn_momentum)
            for _ in range(4)
        ])

        # (b,c,y,x) ----------> (b,c,y,x)
        # (b,c*2,y/2,x/2) ----> (b,c*2,y/2,x/2)
        # (b,c*4,y/4,x/4) +---> (b,c*4,y/4,x/4)
        #                 +---> (b,c*8,y/8,x/8)
        self.transition3 = nn.ModuleList([
            nn.Sequential(),
            nn.Sequential(),
            nn.Sequential(),
            # Double Sequential to fit with official pretrained weights
            nn.Sequential(conv_bn_relu(c * 4, c * 8, 2)),
        ])

        # Stage 4: three fusion modules over four branches; the last one emits
        # only the highest-resolution branch.
        self.stage4 = nn.Sequential(
            StageModule(stage=4, output_branches=4, c=c, bn_momentum=bn_momentum),
            StageModule(stage=4, output_branches=4, c=c, bn_momentum=bn_momentum),
            StageModule(stage=4, output_branches=1, c=c, bn_momentum=bn_momentum),
        )

        # Head on the highest-resolution result: (b,c,y,x) -> (b,nof_joints,y,x)
        self.final_layer = nn.Conv2d(c, nof_joints, kernel_size=1, stride=1)

        self.apply(weights_init)

    def forward(self, x):
        """Map an image batch (b,3,H,W) to a heatmap batch with nof_joints channels."""
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.layer1(x)

        branches = [transition(x) for transition in self.transition1]
        branches = self.stage2(branches)

        # Each new branch is derived from the last existing branch, so the
        # final source index is repeated.
        branches = [self.transition2[k](branches[src])
                    for k, src in enumerate((0, 1, 1))]
        branches = self.stage3(branches)

        branches = [self.transition3[k](branches[src])
                    for k, src in enumerate((0, 1, 2, 2))]
        branches = self.stage4(branches)

        return self.final_layer(branches[0])
|
| 153 |
+
|
| 154 |
+
def hr_w32():
    """Return an HRNet whose highest-resolution branch is 32 channels wide."""
    width = 32
    return HRNet(width)
|
| 156 |
+
|
| 157 |
+
if __name__ == '__main__':
    # Smoke test: run one random 256x256 RGB batch through hr_w32() and print
    # the size of the result.
    import torch

    net = hr_w32()
    dummy = torch.randn(1, 3, 256, 256)
    print(net(dummy).size())
|
models/modules/__init__.py
ADDED
|
File without changes
|
models/modules/blocks/__init__.py
ADDED
|
File without changes
|
models/modules/blocks/basic_block.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class BasicBlock(nn.Module):
    """
    Residual block of two 3x3 convolutions that keeps the tensor shape.

    (b,c,y,x) -> (b,c,y,x)
    """
    expansion = 1

    def __init__(self, planes, bn_momentum=.1):
        super().__init__()

        def conv3x3():
            # Shape-preserving 3x3 convolution: stride 1, padding 1, no bias.
            return nn.Conv2d(planes, planes, kernel_size=3,
                             stride=1, padding=1, bias=False)

        self.conv1 = conv3x3()
        self.bn1 = nn.BatchNorm2d(planes, momentum=bn_momentum)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = conv3x3()
        self.bn2 = nn.BatchNorm2d(planes, momentum=bn_momentum)

    def forward(self, x):
        """Return relu(bn2(conv2(relu(bn1(conv1(x))))) + x)."""
        shortcut = x

        y = self.conv1(x)
        y = self.bn1(y)
        y = self.relu(y)
        y = self.conv2(y)
        y = self.bn2(y)

        # Identity shortcut, then the final activation.
        y += shortcut
        return self.relu(y)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
if __name__ == '__main__':
    # Smoke test: the block must return a tensor of the same size as its input.
    import torch

    block = BasicBlock(256)
    sample = torch.randn(1, 256, 128, 128)
    print(block(sample).size())  # torch.Size([1,256,128,128])
|
models/modules/blocks/bottleneck.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class Bottleneck(nn.Module):
    """
    Residual bottleneck: 1x1 reduce, 3x3, 1x1 expand, plus a shortcut.

    (b,c_in,y,x) -> (b,4*c_out,y,x)
    """

    expansion = 4

    def __init__(self, inplanes, planes, downsample=None, bn_momentum=.1):
        super().__init__()
        out_planes = planes * self.expansion

        # 1x1 convolution: inplanes -> planes
        self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes, momentum=bn_momentum)

        # 3x3 convolution at constant width and resolution
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes, momentum=bn_momentum)

        # 1x1 convolution: planes -> planes * expansion
        self.conv3 = nn.Conv2d(planes, out_planes, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(out_planes, momentum=bn_momentum)

        self.relu = nn.ReLU(inplace=True)
        # Optional module applied to the input before it is added back; needed
        # when the input and output channel counts differ.
        self.downsample = downsample

    def forward(self, x):
        """Apply the three conv/BN stages, add the shortcut, then ReLU."""
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.relu(self.bn2(self.conv2(y)))
        y = self.bn3(self.conv3(y))

        shortcut = x if self.downsample is None else self.downsample(x)

        y += shortcut
        return self.relu(y)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
if __name__ == '__main__':
    # Smoke test for both shortcut variants.
    import torch

    # Channel-changing block: the shortcut needs a 1x1 projection (64 -> 256).
    projection = nn.Sequential(
        nn.Conv2d(64, 256, kernel_size=1, stride=1, bias=False),
        nn.BatchNorm2d(256),
    )
    block = Bottleneck(64, 64, downsample=projection)
    sample = torch.randn(1, 64, 128, 128)
    print(block(sample).size())  # torch.Size([1,256,128,128])

    # Channel-preserving block: identity shortcut.
    block = Bottleneck(256, 64)
    sample = torch.randn(1, 256, 128, 128)
    print(block(sample).size())  # torch.Size([1,256,128,128])
|
models/modules/stage_module.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
|
| 3 |
+
from models.modules.blocks.basic_block import BasicBlock
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class StageModule(nn.Module):
    """Parallel per-resolution branches followed by cross-branch fusion.

    :param stage: number of input branches; branch i has c * 2**i channels
    :param output_branches: number of fused outputs to produce
    :param c: channel width of branch 0
    :param bn_momentum: BatchNorm momentum passed to every BasicBlock
    """

    def __init__(self, stage, output_branches, c, bn_momentum):
        super().__init__()

        self.stage = stage
        self.output_branches = output_branches

        # One branch per input resolution. E.g. stage=3, c=32:
        #   branch 0 -> 4 x BasicBlock(32)
        #   branch 1 -> 4 x BasicBlock(64)
        #   branch 2 -> 4 x BasicBlock(128)
        self.branches = nn.ModuleList()
        for level in range(self.stage):
            width = c * (2 ** level)
            self.branches.append(nn.Sequential(
                *[BasicBlock(width, bn_momentum=bn_momentum) for _ in range(4)]
            ))

        # fuse_layers[i][j] converts branch j into the shape of output i:
        #   i == j : passed through unchanged
        #   i <  j : 1x1 conv to fewer channels, then upsampled
        #   i >  j : chain of stride-2 3x3 convs (downsampled)
        self.fuse_layers = nn.ModuleList()
        for out_idx in range(self.output_branches):
            row = nn.ModuleList()
            self.fuse_layers.append(row)
            out_ch = c * (2 ** out_idx)

            for in_idx in range(self.stage):
                in_ch = c * (2 ** in_idx)

                if out_idx == in_idx:
                    row.append(nn.Sequential())

                elif out_idx < in_idx:
                    row.append(nn.Sequential(
                        nn.Conv2d(in_ch, out_ch, kernel_size=1,
                                  stride=1, bias=False),
                        nn.BatchNorm2d(out_ch),
                        nn.Upsample(scale_factor=(2. ** (in_idx - out_idx))),
                    ))

                else:
                    # All but the last step keep the channel count and end in
                    # ReLU; the last step changes it and has no activation.
                    steps = [
                        nn.Sequential(
                            nn.Conv2d(in_ch, in_ch, kernel_size=3,
                                      stride=2, padding=1, bias=False),
                            nn.BatchNorm2d(in_ch),
                            nn.ReLU(inplace=True),
                        )
                        for _ in range(out_idx - in_idx - 1)
                    ]
                    steps.append(nn.Sequential(
                        nn.Conv2d(in_ch, out_ch, kernel_size=3,
                                  stride=2, padding=1, bias=False),
                        nn.BatchNorm2d(out_ch),
                    ))
                    row.append(nn.Sequential(*steps))

        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Run each input through its branch and return the list of fused outputs."""
        # Pass every input through its own branch.
        branch_outs = [branch(inp) for branch, inp in zip(self.branches, x)]

        # Each output is the sum over all branches of the converted branch result.
        fused = []
        for out_idx in range(len(self.fuse_layers)):
            row = self.fuse_layers[out_idx]
            for in_idx in range(len(self.branches)):
                term = row[in_idx](branch_outs[in_idx])
                if in_idx == 0:
                    fused.append(term)
                else:
                    fused[out_idx] = fused[out_idx] + term

        # ReLU on every fused output.
        return [self.relu(total) for total in fused]
|
models/modules/stem.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class Stem(nn.Module):
    """
    Stem module: downsamples by a factor of 4 and sets the channel count to 64.

    (b,3,y,x) -> (b,64,y/4,x/4)

    :param bn_momentum: momentum of both BatchNorm layers
    """
    def __init__(self, bn_momentum=.1):
        super(Stem, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64, momentum=bn_momentum)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3,
                               stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64, momentum=bn_momentum)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv -> BN -> ReLU twice; each stride-2 convolution halves y and x."""
        # A ReLU follows each BatchNorm. Without the first one the two
        # convolutions are separated only by an affine normalisation.
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        return out
|
| 23 |
+
|
| 24 |
+
if __name__ == '__main__':
    # Smoke test: a 128x64 input should come out with 64 channels at 32x16.
    import torch

    stem = Stem()
    sample = torch.randn(1, 3, 128, 64)
    print(stem(sample).size())  # torch.Size([1,64,32,16])
|
requirements.txt
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
--extra-index-url https://download.pytorch.org/whl/cu113
Cython==3.0.5
gradio==4.8.0
gradio_client==0.7.1
huggingface-hub==0.19.4
imageio==2.33.0
numpy==1.24.3
opencv-python==4.8.1.78
opendatalab==0.0.10
Pillow==10.0.1
pip==23.3
pycocotools==2.0.7
scikit-image==0.21.0
scipy==1.10.1
torch==1.12.0+cu113
torchaudio==0.12.0+cu113
torchvision==0.13.0+cu113
tqdm==4.65.2
|
tool_utils.py
ADDED
|
@@ -0,0 +1,372 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
import matplotlib.image as mpimg
|
| 4 |
+
import cv2
|
| 5 |
+
import skimage
|
| 6 |
+
import torch
|
| 7 |
+
from PIL import Image
|
| 8 |
+
|
| 9 |
+
# Names of the 16 joints; show_heatmaps uses entry k as the title of heatmap channel k.
joints = (
    ['left ankle', 'left knee', 'left hip']
    + ['right hip', 'right knee', 'right ankle']
    + ['belly', 'chest', 'neck', 'head']
    + ['left wrist', 'left elbow', 'left shoulder']
    + ['right shoulder', 'right elbow', 'right wrist']
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def generate_heatmap(heatmap, pt, sigma=(33, 33), sigma_valu=7):
    '''
    Render one joint as a Gaussian blob whose maximum is scaled to 1.

    :param heatmap: 2-D np array (H,W), expected to be all zeros; the pixel at
        ``pt`` is set to 1 in place before blurring
    :param pt: point coords, indexed as pt[0] = x (column), pt[1] = y (row)
    :param sigma: unused (obsolete), kept so existing calls still work
    :param sigma_valu: sigma passed to the Gaussian filter
    :return: a np array of one joint heatmap with shape (H,W)

    This function is obsolete, use 'generate_heatmaps()' instead.
    '''
    row, col = int(pt[1]), int(pt[0])
    heatmap[row][col] = 1
    blurred = skimage.filters.gaussian(heatmap, sigma=sigma_valu)
    # Rescale so the peak value is exactly 1.
    return blurred / np.amax(blurred)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def generate_heatmaps(img, pts, sigma=(33, 33), sigma_valu=7):
    '''
    Build one Gaussian heatmap per joint, each scaled so its peak is 1.

    :param img: np array img, (H,W,C); only its height and width are used
    :param pts: joint point coords, np array, same resolution as img;
        pt[0] is x (column), pt[1] is y (row). It is not modified.
    :param sigma: unused (obsolete), kept so existing calls still work
    :param sigma_valu: sigma passed to the Gaussian filter
    :return: np array heatmaps, (H,W,num_pts); the channel of a joint at
        (0, 0) is left all zeros
    '''
    H, W = img.shape[0], img.shape[1]
    num_pts = pts.shape[0]
    heatmaps = np.zeros((H, W, num_pts))
    for i, pt in enumerate(pts):
        # (0, 0) marks an unavailable joint: leave its channel empty.
        if pt[0] == 0 and pt[1] == 0:
            continue
        # Clamp into the image on local copies, so the caller's `pts` stays
        # untouched and a negative coordinate cannot wrap around as a
        # negative index.
        col = min(max(int(pt[0]), 0), W - 1)
        row = min(max(int(pt[1]), 0), H - 1)
        heatmap = heatmaps[:, :, i]
        heatmap[row, col] = 1
        heatmap = skimage.filters.gaussian(heatmap, sigma=sigma_valu)
        heatmaps[:, :, i] = heatmap / np.amax(heatmap)
    return heatmaps
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def load_image(path_image):
    '''
    Read the image file at ``path_image`` with matplotlib.

    :param path_image: path of the image file
    :return: the decoded image as a np array (H,W,C)
    '''
    return mpimg.imread(path_image)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def crop(img, ele_anno, use_randscale=True, use_randflipLR=False, use_randcolor=False):
    '''
    Crop a square window around the annotated joints, with optional augmentation.

    :param img: np array of the origin image, (H,W,C)
    :param ele_anno: one element of json annotation; the keys read are
        'scale_provided', 'objpos' and 'joint_self'. It is not modified.
    :param use_randscale: enlarge the margin around the joints by a random factor in [1, 3)
    :param use_randflipLR: flip left/right with probability 0.5
    :param use_randcolor: with probability 0.5, scale each colour channel by a random factor
    :return: img_crop, ary_pts_crop, c_crop after cropping; img_crop is
        float64 and divided by 255
    '''

    H, W = img.shape[0], img.shape[1]
    s = ele_anno['scale_provided']
    # Work on a copy: the centre is adjusted below, and writing into the
    # annotation's own list would shift it further on every call.
    c = list(ele_anno['objpos'])

    # Adjust center and scale
    if c[0] != -1:
        c[1] = c[1] + 15 * s
        s = s * 1.25
    ary_pts = np.array(ele_anno['joint_self'])  # (16, 3)
    # Rows that are not entirely zero.
    ary_pts_temp = ary_pts[np.any(ary_pts != [0, 0, 0], axis=1)]

    if use_randscale:
        scale_rand = np.random.uniform(low=1.0, high=3.0)
    else:
        scale_rand = 1

    # Bounding box of those joints plus a margin, clipped to the image.
    W_min = max(np.amin(ary_pts_temp, axis=0)[0] - s * 15 * scale_rand, 0)
    H_min = max(np.amin(ary_pts_temp, axis=0)[1] - s * 15 * scale_rand, 0)
    W_max = min(np.amax(ary_pts_temp, axis=0)[0] + s * 15 * scale_rand, W)
    H_max = min(np.amax(ary_pts_temp, axis=0)[1] + s * 15 * scale_rand, H)
    W_len = W_max - W_min
    H_len = H_max - H_min
    # Grow the shorter side so the window is as square as the image allows.
    window_len = max(H_len, W_len)
    pad_updown = (window_len - H_len)/2
    pad_leftright = (window_len - W_len)/2

    # Calculate 4 corner position
    W_low = max((W_min - pad_leftright), 0)
    W_high = min((W_max + pad_leftright), W)
    H_low = max((H_min - pad_updown), 0)
    H_high = min((H_max + pad_updown), H)

    # Update joint points and center
    # NOTE(review): np.where compares element by element, so any single
    # coordinate equal to 0 is left unshifted, not only all-zero rows - confirm intended.
    ary_pts_crop = np.where(
        ary_pts == [0, 0, 0], ary_pts, ary_pts - np.array([W_low, H_low, 0]))
    c_crop = c - np.array([W_low, H_low])

    img_crop = img[int(H_low):int(H_high), int(W_low):int(W_high), :]

    # Pad when H, W different
    H_new, W_new = img_crop.shape[0], img_crop.shape[1]
    window_len_new = max(H_new, W_new)
    pad_updown_new = int((window_len_new - H_new)/2)
    pad_leftright_new = int((window_len_new - W_new)/2)

    # ReUpdate joint points and center (because of the padding)
    ary_pts_crop = np.where(ary_pts_crop == [
                            0, 0, 0], ary_pts_crop, ary_pts_crop + np.array([pad_leftright_new, pad_updown_new, 0]))
    c_crop = c_crop + np.array([pad_leftright_new, pad_updown_new])

    img_crop = cv2.copyMakeBorder(img_crop, pad_updown_new, pad_updown_new,
                                  pad_leftright_new, pad_leftright_new, cv2.BORDER_CONSTANT, value=0)

    # change dtype and num scale
    img_crop = img_crop / 255.
    img_crop = img_crop.astype(np.float64)

    if use_randflipLR:
        flip = np.random.random() > 0.5
        if flip:
            # (H,W,C)
            img_crop = np.flip(img_crop, 1)
            # Calculate flip pts, remember to filter [0,0] which is no available heatmap
            # NOTE(review): the [-1, 1, 0] factor leaves the third column of
            # every flipped point at 0 - confirm intended.
            ary_pts_crop = np.where(ary_pts_crop == [0, 0, 0], ary_pts_crop,
                                    [window_len_new, 0, 0] + ary_pts_crop * [-1, 1, 0])
            c_crop = [window_len_new, 0] + c_crop * [-1, 1]
            # Rearrange pts: rows 0-5 and 10-15 are reversed so left and right
            # joints swap places; rows 6-9 keep their order.
            ary_pts_crop = np.concatenate(
                (ary_pts_crop[5::-1], ary_pts_crop[6:10], ary_pts_crop[15:9:-1]))

    if use_randcolor:
        randcolor = np.random.random() > 0.5
        if randcolor:
            # Each factor is drawn from [0.8, 1.2) and then clipped to [0, 1],
            # so a channel is never brightened.
            img_crop[...,
                     0] *= np.clip(np.random.uniform(low=0.8, high=1.2), 0., 1.)
            img_crop[...,
                     1] *= np.clip(np.random.uniform(low=0.8, high=1.2), 0., 1.)
            img_crop[...,
                     2] *= np.clip(np.random.uniform(low=0.8, high=1.2), 0., 1.)

    return img_crop, ary_pts_crop, c_crop
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def change_resolu(img, pts, c, resolu_out=(256, 256)):
    '''
    Resize an image and rescale its joint points and center to match.

    :param img: np array of the origin image
    :param pts: joint points np array corresponding to the image, same
        resolution as img; columns are (x, y, third value), and the third
        column is left unchanged
    :param c: center, (x, y)
    :param resolu_out: a list or tuple, (H_out, W_out)
    :return: img_out, pts_out, c_out under resolu_out
    '''
    in_h, in_w = img.shape[0], img.shape[1]
    out_h, out_w = resolu_out[0], resolu_out[1]

    # Input/output size ratio along x and y.
    ratio_xy = np.array([in_w / out_w, in_h / out_h])

    pts_out = pts / np.array([ratio_xy[0], ratio_xy[1], 1])
    c_out = c / ratio_xy
    img_out = skimage.transform.resize(img, tuple(resolu_out))

    return img_out, pts_out, c_out
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def heatmaps_to_coords(heatmaps, resolu_out=(64, 64), prob_threshold=0.2):
    '''
    Decode each joint heatmap into the location and value of its peak.

    :param heatmaps: np array with shape (H,W,num_joints)
    :param resolu_out: output resolution, list or tuple; the heatmaps are
        resized to it before decoding
    :param prob_threshold: a joint whose peak is below this value gets coords (0, 0)
    :return coord_joints: np array, shape (num_joints,3), rows are [x, y, peak value]
    '''

    num_joints = heatmaps.shape[2]
    # Resize
    heatmaps = skimage.transform.resize(heatmaps, tuple(resolu_out))

    coord_joints = np.zeros((num_joints, 3))
    for i in range(num_joints):
        heatmap = heatmaps[..., i]
        peak = np.max(heatmap)
        # Only keep points larger than a threshold
        if peak >= prob_threshold:
            # First maximum in row-major order.
            row, col = np.unravel_index(np.argmax(heatmap), heatmap.shape)
        else:
            row, col = 0, 0
        coord_joints[i] = [col, row, peak]
    return coord_joints
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def show_heatmaps(img, heatmaps, c=np.zeros((2)), num_fig=1):
    '''
    Plot the image and every joint heatmap on a 4x5 grid of subplots.

    :param img: np array (H,W,3)
    :param heatmaps: np array (H,W,num_pts); resized to the image size when
        its height differs from the image's
    :param c: center, np array (2,); not used by this function
    :param num_fig: number of the matplotlib figure to draw into
    '''
    H, W = img.shape[0], img.shape[1]

    if heatmaps.shape[0] != H:
        heatmaps = skimage.transform.resize(heatmaps, (H, W))

    plt.figure(num_fig)

    # Cell 1: the original image.
    plt.subplot(4, 5, 1)
    plt.title('Origin')
    plt.imshow(img)
    plt.axis('off')

    # Cells 2..: one heatmap channel each, titled with its joint name.
    for k in range(heatmaps.shape[2]):
        plt.subplot(4, 5, k + 2)
        plt.title(joints[k])
        plt.imshow(heatmaps[:, :, k])
        plt.axis('off')

    # Hide the axes of the last grid cell as well.
    plt.subplot(4, 5, 20)
    plt.axis('off')
    plt.show()
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def heatmap2rgb(heatmap):
    """
    Colorize a single heatmap with matplotlib's 'jet' colormap.

    The heatmap is min-max normalized, passed through the colormap, and the
    first three (RGB) channels of the result are returned channel-first.

    : heatmap: torch tensor, (h,w)
    : return: torch tensor, (3,h,w)
    """
    heatmap = heatmap.detach().cpu().numpy()

    cm = plt.get_cmap('jet')

    lowest = np.min(heatmap)
    highest = np.max(heatmap)
    # The epsilon must be added to the value range, not to the data: it
    # keeps the denominator non-zero for a constant heatmap (previously a
    # constant float32 map could produce 0/0 = NaN) and no longer shrinks
    # the range for non-constant maps.
    normed_data = (heatmap - lowest) / (highest - lowest + 1e-8)
    mapped_data = cm(normed_data)

    # (h,w,c): keep the first three colormap channels only
    img = np.array(mapped_data)
    img = img[:, :, :3]

    # (c,h,w)
    return torch.tensor(img).permute(2, 0, 1)
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
def heatmaps2rgb(heatmaps):
    """
    Colorize every heatmap of a batch with heatmap2rgb and stack the results.

    : heatmaps: (b,h,w)
    : return: torch tensor stacking the per-heatmap results along a new
      leading dimension
    """
    return torch.stack([heatmap2rgb(single_map) for single_map in heatmaps])
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
# def draw_joints(img, pts):
|
| 297 |
+
# scores = pts[:,2]
|
| 298 |
+
# pts = np.array(pts).astype(int)
|
| 299 |
+
|
| 300 |
+
# for i in range(pts.shape[0]):
|
| 301 |
+
# if pts[i, 0] != 0 and pts[i, 1] != 0:
|
| 302 |
+
# img = cv2.circle(img, (pts[i, 0], pts[i, 1]), radius=3,
|
| 303 |
+
# color=(255, 0, 0), thickness=-1)
|
| 304 |
+
# print('img',img.max(),img.min())
|
| 305 |
+
# # img = cv2.putText(img, f'{joints[i]}: {scores[i]:.2f}', (
|
| 306 |
+
# # pts[i, 0]+5, pts[i, 1]-5), cv2.FONT_HERSHEY_SIMPLEX, .25, (255, 0, 0))
|
| 307 |
+
|
| 308 |
+
# # Left arm
|
| 309 |
+
# for i in range(10, 13-1):
|
| 310 |
+
# if pts[i, 0] != 0 and pts[i, 1] != 0 and pts[i+1, 0] != 0 and pts[i+1, 1] != 0:
|
| 311 |
+
# img = cv2.line(img, (pts[i, 0], pts[i, 1]), (pts[i+1, 0],
|
| 312 |
+
# pts[i+1, 1]), color=(255, 0, 0), thickness=1)
|
| 313 |
+
|
| 314 |
+
# # Right arm
|
| 315 |
+
# for i in range(13, 16-1):
|
| 316 |
+
# if pts[i, 0] != 0 and pts[i, 1] != 0 and pts[i+1, 0] != 0 and pts[i+1, 1] != 0:
|
| 317 |
+
# img = cv2.line(img, (pts[i, 0], pts[i, 1]), (pts[i+1, 0],
|
| 318 |
+
# pts[i+1, 1]), color=(255, 0, 0), thickness=1)
|
| 319 |
+
|
| 320 |
+
# # Left leg
|
| 321 |
+
# for i in range(0, 3-1):
|
| 322 |
+
# if pts[i, 0] != 0 and pts[i, 1] != 0 and pts[i+1, 0] != 0 and pts[i+1, 1] != 0:
|
| 323 |
+
# img = cv2.line(img, (pts[i, 0], pts[i, 1]), (pts[i+1, 0],
|
| 324 |
+
# pts[i+1, 1]), color=(255, 0, 0), thickness=1)
|
| 325 |
+
# # Right leg
|
| 326 |
+
# for i in range(3, 6-1):
|
| 327 |
+
# if pts[i, 0] != 0 and pts[i, 1] != 0 and pts[i+1, 0] != 0 and pts[i+1, 1] != 0:
|
| 328 |
+
# img = cv2.line(img, (pts[i, 0], pts[i, 1]), (pts[i+1, 0],
|
| 329 |
+
# pts[i+1, 1]), color=(255, 0, 0), thickness=1)
|
| 330 |
+
|
| 331 |
+
# # Body
|
| 332 |
+
# for i in range(6, 10-1):
|
| 333 |
+
# if pts[i, 0] != 0 and pts[i, 1] != 0 and pts[i+1, 0] != 0 and pts[i+1, 1] != 0:
|
| 334 |
+
# img = cv2.line(img, (pts[i, 0], pts[i, 1]), (pts[i+1, 0],
|
| 335 |
+
# pts[i+1, 1]), color=(255, 0, 0), thickness=1)
|
| 336 |
+
|
| 337 |
+
# if pts[2, 0] != 0 and pts[2, 1] != 0 and pts[3, 0] != 0 and pts[3, 1] != 0:
|
| 338 |
+
# img = cv2.line(img, (pts[2, 0], pts[2, 1]), (pts[2+1, 0],
|
| 339 |
+
# pts[2+1, 1]), color=(255, 0, 0), thickness=1)
|
| 340 |
+
# if pts[12, 0] != 0 and pts[12, 1] != 0 and pts[13, 0] != 0 and pts[13, 1] != 0:
|
| 341 |
+
# img = cv2.line(img, (pts[12, 0], pts[12, 1]), (pts[12+1, 0],
|
| 342 |
+
# pts[12+1, 1]), color=(255, 0, 0), thickness=1)
|
| 343 |
+
|
| 344 |
+
# return img
|
| 345 |
+
def draw_joints(img, pts):
    """
    Draw the skeleton limbs connecting the joints in ``pts`` onto a copy of ``img``.

    :param img: np array; multiplied by 255 and cast to uint8 before drawing
        (presumably a float image in [0, 1] -- confirm against the caller)
    :param pts: indexable joint list with at least 16 entries; entry i
        starts with the (x, y) location of joint i
    :return: the drawn image divided by 255.0
    """
    # Work on a uint8 copy scaled to [0, 255] for visualization
    canvas = (img * 255).astype(np.uint8)

    # Joint index pairs to connect, listed in drawing order.
    limb_pairs = (
        (10, 11), (11, 12),
        (13, 14), (14, 15),
        (0, 1), (1, 2),
        (3, 4), (4, 5),
        (6, 7), (7, 8), (8, 9),
        (2, 3),
        (12, 13),
    )
    for start, end in limb_pairs:
        draw_line(canvas, pts[start], pts[end])

    return canvas / 255.0
|
| 369 |
+
|
| 370 |
+
def draw_line(img, pt1, pt2):
    """
    Draw a 1-pixel line with color (255, 0, 0) between two points, in place on ``img``.

    Nothing is drawn when any of the four coordinates equals 0.

    :param img: image passed to cv2.line (modified in place)
    :param pt1: first endpoint; pt1[0] is x, pt1[1] is y
    :param pt2: second endpoint; pt2[0] is x, pt2[1] is y
    """
    # A zero coordinate on either endpoint means the segment is skipped.
    if pt1[0] == 0 or pt1[1] == 0 or pt2[0] == 0 or pt2[1] == 0:
        return

    start = (int(pt1[0]), int(pt1[1]))
    end = (int(pt2[0]), int(pt2[1]))
    cv2.line(img, start, end, color=(255, 0, 0), thickness=1)
|