# feat: upload project (commit 93a6fff)
import torch
from torch import nn
from models.modules.blocks.bottleneck import Bottleneck
from models.modules.stage_module import StageModule
def weights_init(m):
    """Module initializer meant for `model.apply(...)`.

    Conv2d weights are drawn from N(0, 0.01); BatchNorm2d weights are set
    to 1. Biases of both module types, when present, are zeroed.
    """
    is_conv = isinstance(m, nn.Conv2d)
    is_bn = isinstance(m, nn.BatchNorm2d)
    if not (is_conv or is_bn):
        return
    if is_conv:
        nn.init.normal_(m.weight, std=0.01)
    else:
        nn.init.ones_(m.weight)
    if m.bias is not None:
        nn.init.zeros_(m.bias)
class HRNet(nn.Module):
    """High-Resolution Network (HRNet) producing per-joint heatmaps.

    A high-resolution branch is kept throughout; lower-resolution branches
    are added stage by stage and repeatedly fused via StageModules.

    Args:
        c: channel width of the highest-resolution branch
            (c=48 -> HRNet-W48, c=32 -> HRNet-W32).
        nof_joints: number of output channels of the final 1x1 head.
        bn_momentum: momentum applied to every BatchNorm2d in the network.
    """

    def __init__(self, c=48, nof_joints=16, bn_momentum=.1):
        super(HRNet, self).__init__()
        # Stem: two stride-2 3x3 convs, so the input is downsampled by 4.
        # In the shape comments below, (y,x) denotes the post-stem resolution.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3,
                               stride=2, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64, momentum=bn_momentum)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3,
                               stride=2, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(64, momentum=bn_momentum)
        self.relu = nn.ReLU(inplace=True)
        # Stage 1: (b,64,y,x) -> (b,256,y,x)
        downsample = nn.Sequential(
            nn.Conv2d(64, 256, kernel_size=1, stride=1, bias=False),
            # propagate bn_momentum here too, consistent with every other BN
            nn.BatchNorm2d(256, momentum=bn_momentum),
        )
        self.layer1 = nn.Sequential(
            Bottleneck(64, 64, downsample=downsample),
            Bottleneck(256, 64),
            Bottleneck(256, 64),
            Bottleneck(256, 64),
        )
        # Split into two branches:
        # (b,256,y,x) ---+---> (b,c,y,x)
        #                +---> (b,c*2,y/2,x/2)
        self.transition1 = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(256, c, kernel_size=3,
                          stride=1, padding=1, bias=False),
                nn.BatchNorm2d(c, momentum=bn_momentum),
                nn.ReLU(inplace=True),
            ),
            # Double Sequential to fit with official pretrained weights.
            nn.Sequential(nn.Sequential(
                nn.Conv2d(256, c * 2, kernel_size=3,
                          stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c * 2, momentum=bn_momentum),
                nn.ReLU(inplace=True),
            ))
        ])
        # Each StageModule fuses information across its branches.
        # (b,c,y,x) ------+---> (b,c,y,x)
        # (b,c*2,y/2,x/2) +---> (b,c*2,y/2,x/2)
        self.stage2 = nn.Sequential(
            StageModule(stage=2, output_branches=2, c=c, bn_momentum=bn_momentum)
        )
        # Add a third branch from the lowest-resolution one:
        # (b,c,y,x) ----------> (b,c,y,x)
        # (b,c*2,y/2,x/2) +---> (b,c*2,y/2,x/2)
        #                 +---> (b,c*4,y/4,x/4)
        self.transition2 = nn.ModuleList([
            nn.Sequential(),  # identity; Sequential() used because it is callable
            nn.Sequential(),
            nn.Sequential(nn.Sequential(
                nn.Conv2d(c * 2, c * 4, kernel_size=3,
                          stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c * 4, momentum=bn_momentum),
                nn.ReLU(inplace=True),
            ))
        ])
        # (b,c,y,x) ------++++---> (b,c,y,x)
        # (b,c*2,y/2,x/2) ++++---> (b,c*2,y/2,x/2)
        # (b,c*4,y/4,x/4) ++++---> (b,c*4,y/4,x/4)
        self.stage3 = nn.Sequential(
            StageModule(stage=3, output_branches=3, c=c, bn_momentum=bn_momentum),
            StageModule(stage=3, output_branches=3, c=c, bn_momentum=bn_momentum),
            StageModule(stage=3, output_branches=3, c=c, bn_momentum=bn_momentum),
            StageModule(stage=3, output_branches=3, c=c, bn_momentum=bn_momentum),
        )
        # Add a fourth branch from the lowest-resolution one:
        # (b,c,y,x) ----------> (b,c,y,x)
        # (b,c*2,y/2,x/2) ----> (b,c*2,y/2,x/2)
        # (b,c*4,y/4,x/4) +---> (b,c*4,y/4,x/4)
        #                 +---> (b,c*8,y/8,x/8)
        self.transition3 = nn.ModuleList([
            nn.Sequential(),  # None, - Used in place of "None" because it is callable
            nn.Sequential(),  # None, - Used in place of "None" because it is callable
            nn.Sequential(),  # None, - Used in place of "None" because it is callable
            nn.Sequential(nn.Sequential(  # Double Sequential to fit with official pretrained weights
                nn.Conv2d(c * 4, c * 8, kernel_size=3,
                          stride=2, padding=1, bias=False),
                nn.BatchNorm2d(c * 8, momentum=bn_momentum),
                nn.ReLU(inplace=True),
            )),
        ])
        # (b,c,y,x) ------+++---> (b,c,y,x)
        # (b,c*2,y/2,x/2) +++---> (b,c*2,y/2,x/2)
        # (b,c*4,y/4,x/4) +++---> (b,c*4,y/4,x/4)
        # (b,c*8,y/8,x/8) +++---> (b,c*8,y/8,x/8)
        # The last module emits a single branch (the highest resolution).
        self.stage4 = nn.Sequential(
            StageModule(stage=4, output_branches=4, c=c, bn_momentum=bn_momentum),
            StageModule(stage=4, output_branches=4, c=c, bn_momentum=bn_momentum),
            StageModule(stage=4, output_branches=1, c=c, bn_momentum=bn_momentum),
        )
        # 1x1 head on the highest-resolution branch only:
        # (b,c,y,x) -> (b,nof_joints,y,x)
        self.final_layer = nn.Conv2d(c, nof_joints, kernel_size=1, stride=1)
        self.apply(weights_init)

    def forward(self, x):
        """Return heatmaps of shape (b, nof_joints, H/4, W/4) for input (b, 3, H, W)."""
        x = self.relu(self.bn1(self.conv1(x)))
        x = self.relu(self.bn2(self.conv2(x)))
        x = self.layer1(x)
        x = [trans(x) for trans in self.transition1]
        x = self.stage2(x)
        x = [
            self.transition2[0](x[0]),
            self.transition2[1](x[1]),
            self.transition2[2](x[1]),  # new branch built from the lowest existing one
        ]
        x = self.stage3(x)
        x = [
            self.transition3[0](x[0]),
            self.transition3[1](x[1]),
            self.transition3[2](x[2]),
            self.transition3[3](x[2]),  # new branch built from the lowest existing one
        ]
        x = self.stage4(x)
        # stage4's last StageModule outputs only the highest-resolution branch
        x = x[0]
        out = self.final_layer(x)
        return out
def hr_w32():
    """Build an HRNet-W32 (highest-resolution branch width of 32)."""
    return HRNet(32)
if __name__ == '__main__':
    # Smoke test: run one dummy batch through HRNet-W32.
    # (torch is already imported at module level; the duplicate import was removed.)
    model = hr_w32()
    x = torch.randn(1, 3, 256, 256)
    output = model(x)
    # Expected: torch.Size([1, 16, 64, 64]) — the stem downsamples by 4 and
    # the 1x1 head keeps spatial resolution.
    print(output.size())