File size: 1,832 Bytes
57d1795 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
import torch
import torch.nn as nn
from fastai.vision import *
from modules.model import _default_tfmer_cfg
from modules.resnet import resnet45
from modules.transformer import (PositionalEncoding,
TransformerEncoder,
TransformerEncoderLayer)
class ResTranformer(nn.Module):
    """ResNet-45 backbone followed by a Transformer encoder over its feature map.

    Hyper-parameters come from ``config``. Each ``model_vision_*`` entry that is
    ``None`` falls back (via ``ifnone``) to ``_default_tfmer_cfg``. The exceptions
    are the backbone width factor, which defaults to ``1.``, and the encoder layer
    count, which defaults to ``2``.
    """

    def __init__(self, config):
        super().__init__()
        defaults = _default_tfmer_cfg

        backbone_alpha = ifnone(config.model_vision_backbone_alpha_d, 1.)
        self.d_model = ifnone(config.model_vision_d_model, defaults['d_model'])
        # The backbone is asked to emit d_model channels, matching the encoder width.
        self.resnet = resnet45(backbone_alpha, output_channels=self.d_model)

        n_heads = ifnone(config.model_vision_nhead, defaults['nhead'])
        ffn_dim = ifnone(config.model_vision_d_inner, defaults['d_inner'])
        drop_p = ifnone(config.model_vision_dropout, defaults['dropout'])
        act_name = ifnone(config.model_vision_activation, defaults['activation'])
        depth = ifnone(config.model_vision_backbone_ln, 2)

        # NOTE(review): max_len=8*32 allows 256 positions (h*w of the feature map).
        # Larger backbone outputs presumably fail inside PositionalEncoding; confirm.
        self.pos_encoder = PositionalEncoding(self.d_model, max_len=8*32)
        layer = TransformerEncoderLayer(d_model=self.d_model, nhead=n_heads,
                                        dim_feedforward=ffn_dim, dropout=drop_p,
                                        activation=act_name)
        self.transformer = TransformerEncoder(layer, depth)

    def forward_transformer(self, feature):
        """Apply positional encoding and the encoder to every spatial position.

        ``feature`` is unpacked as ``(n, c, h, w)``. It is flattened to an
        ``(h*w, n, c)`` sequence, passed through ``self.pos_encoder`` and then
        ``self.transformer``, and folded back to ``(n, c, h, w)``.
        """
        batch, channels, height, width = feature.shape
        # (n, c, h, w) -> (n, c, h*w) -> (h*w, n, c)
        seq = feature.view(batch, channels, -1).permute(2, 0, 1)
        seq = self.transformer(self.pos_encoder(seq))
        # (h*w, n, c) -> (n, c, h*w) -> (n, c, h, w); view only splits the last dim.
        return seq.permute(1, 2, 0).view(batch, channels, height, width)

    def forward(self, images, **kwargs):
        """Encode ``images`` with ``self.resnet``, then with the Transformer stage.

        Extra keyword arguments are forwarded unchanged to ``self.resnet``.
        """
        return self.forward_transformer(self.resnet(images, **kwargs))