|
import torch |
|
import torch.nn as nn |
|
|
|
|
|
def patch_first_conv(model, new_in_channels, default_in_channels=3, pretrained=True): |
|
"""Change first convolution layer input channels. |
|
In case: |
|
in_channels == 1 or in_channels == 2 -> reuse original weights |
|
in_channels > 3 -> make random kaiming normal initialization |
|
""" |
|
|
|
|
|
for module in model.modules(): |
|
if isinstance(module, nn.Conv2d) and module.in_channels == default_in_channels: |
|
break |
|
|
|
weight = module.weight.detach() |
|
module.in_channels = new_in_channels |
|
|
|
if not pretrained: |
|
module.weight = nn.parameter.Parameter( |
|
torch.Tensor(module.out_channels, new_in_channels // module.groups, *module.kernel_size) |
|
) |
|
module.reset_parameters() |
|
|
|
elif new_in_channels == 1: |
|
new_weight = weight.sum(1, keepdim=True) |
|
module.weight = nn.parameter.Parameter(new_weight) |
|
|
|
else: |
|
new_weight = torch.Tensor(module.out_channels, new_in_channels // module.groups, *module.kernel_size) |
|
|
|
for i in range(new_in_channels): |
|
new_weight[:, i] = weight[:, i % default_in_channels] |
|
|
|
new_weight = new_weight * (default_in_channels / new_in_channels) |
|
module.weight = nn.parameter.Parameter(new_weight) |
|
|
|
|
|
def replace_strides_with_dilation(module, dilation_rate): |
|
"""Patch Conv2d modules replacing strides with dilation""" |
|
for mod in module.modules(): |
|
if isinstance(mod, nn.Conv2d): |
|
mod.stride = (1, 1) |
|
mod.dilation = (dilation_rate, dilation_rate) |
|
kh, kw = mod.kernel_size |
|
mod.padding = ((kh // 2) * dilation_rate, (kh // 2) * dilation_rate) |
|
|
|
|
|
if hasattr(mod, "static_padding"): |
|
mod.static_padding = nn.Identity() |
|
|