Spaces:
Build error
Build error
| # -------------------------------------------------------- | |
| # SiamMask | |
| # Licensed under The MIT License | |
| # Written by Qiang Wang (wangqiang2015 at ia.ac.cn) | |
| # -------------------------------------------------------- | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class RPN(nn.Module):
    """Base class for region-proposal heads.

    ``forward``, ``template`` and ``track`` are left to subclasses; the base
    implementations raise ``NotImplementedError``. ``param_groups`` is the
    only concrete helper provided here.
    """

    def __init__(self):
        super(RPN, self).__init__()

    def forward(self, z_f, x_f):
        """Run the head; must be overridden.

        ``z_f`` / ``x_f`` are presumably template and search features
        respectively (TODO confirm against subclasses).
        """
        raise NotImplementedError

    def template(self, template):
        """Consume template input; must be overridden."""
        raise NotImplementedError

    def track(self, search):
        """Consume search input; must be overridden."""
        raise NotImplementedError

    def param_groups(self, start_lr, feature_mult=1, key=None):
        """Build an optimizer parameter-group list for this module.

        Args:
            start_lr: base learning rate.
            feature_mult: multiplier applied to ``start_lr`` for this group.
            key: if given, keep only parameters whose qualified name
                (as reported by ``named_parameters``) contains this substring.

        Returns:
            ``[{'params': [...], 'lr': start_lr * feature_mult}]``, where
            ``params`` holds only parameters with ``requires_grad=True``.
            ``params`` is always a concrete list (never a one-shot iterator),
            so it can be inspected, counted or reused before being handed to
            a ``torch.optim`` optimizer.
        """
        params = [
            param
            for name, param in self.named_parameters()
            if param.requires_grad and (key is None or key in name)
        ]
        return [{'params': params, 'lr': start_lr * feature_mult}]
def conv2d_dw_group(x, kernel):
    """Depth-wise cross-correlation of ``x`` with a per-sample ``kernel``.

    Each (sample, channel) map of ``x`` is correlated only with the matching
    (sample, channel) map of ``kernel``. This is done with a single
    ``F.conv2d`` call by folding batch and channel into the group dimension.

    Args:
        x: 4-D tensor ``(batch, channel, H, W)``; it must hold exactly
            ``batch * channel`` maps of size ``H x W``, where ``batch`` and
            ``channel`` are taken from ``kernel``.
        kernel: 4-D tensor ``(batch, channel, kH, kW)``.

    Returns:
        Tensor ``(batch, channel, H - kH + 1, W - kW + 1)`` (stride 1, no
        padding, which are the ``F.conv2d`` defaults used here).
    """
    batch, channel = kernel.shape[:2]
    groups = batch * channel
    # reshape (not view): also works for non-contiguous inputs such as
    # channel slices of a larger feature tensor.
    x = x.reshape(1, groups, x.size(2), x.size(3))  # 1 x (b*c) x H x W
    kernel = kernel.reshape(groups, 1, kernel.size(2), kernel.size(3))  # (b*c) x 1 x kH x kW
    out = F.conv2d(x, kernel, groups=groups)
    return out.reshape(batch, channel, out.size(2), out.size(3))
class DepthCorr(nn.Module):
    """Depth-wise correlation head.

    The two inputs each pass through their own conv-BN-ReLU adjust branch.
    The adjusted maps are correlated channel by channel via
    ``conv2d_dw_group``, and the result is projected to ``out_channels`` by a
    1x1 conv-BN-ReLU-conv head.

    Args:
        in_channels: channels of both incoming feature maps.
        hidden: channels produced by the adjust branches and used inside
            the head.
        out_channels: channels of the final output.
        kernel_size: spatial size of the adjust convolutions (no padding).
    """

    def __init__(self, in_channels, hidden, out_channels, kernel_size=3):
        super(DepthCorr, self).__init__()

        def make_adjust_branch():
            # Adjust layer for asymmetrical features: each call builds a
            # fresh branch, so the kernel and search paths do not share weights.
            return nn.Sequential(
                nn.Conv2d(in_channels, hidden, kernel_size=kernel_size, bias=False),
                nn.BatchNorm2d(hidden),
                nn.ReLU(inplace=True),
            )

        # Registration order (kernel, search, head) fixes parameter order.
        self.conv_kernel = make_adjust_branch()
        self.conv_search = make_adjust_branch()
        self.head = nn.Sequential(
            nn.Conv2d(hidden, hidden, kernel_size=1, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, out_channels, kernel_size=1),
        )

    def forward_corr(self, kernel, input):
        """Adjust both inputs, then return their depth-wise correlation."""
        adjusted_kernel = self.conv_kernel(kernel)
        adjusted_search = self.conv_search(input)
        return conv2d_dw_group(adjusted_search, adjusted_kernel)

    def forward(self, kernel, search):
        """Correlate ``kernel`` with ``search`` and apply the output head."""
        return self.head(self.forward_corr(kernel, search))