|
|
|
import sys |
|
|
|
# Make the project root importable when this file is run directly from the models/ directory.
sys.path.append(sys.path[0].replace('models', ''))
|
|
|
import re |
|
import logging |
|
import math |
|
import json |
|
import pathlib |
|
import numpy as np |
|
from copy import deepcopy |
|
from pathlib import Path |
|
from einops import rearrange |
|
from collections import OrderedDict |
|
from dataclasses import dataclass |
|
from typing import Tuple, Union, Callable, Optional |
|
|
|
|
|
import torch |
|
from torch import nn |
|
import torch.nn.functional as F |
|
import torchvision.models as models |
|
from torch.utils.checkpoint import checkpoint |
|
|
|
from transformers import AutoModel, BertConfig, AutoTokenizer
|
|
|
|
|
|
|
from models.transformer_decoder import * |
|
from torch.autograd import Function |
|
import timm |
|
|
|
class ReverseLayerF(Function):
    """Gradient reversal: identity in the forward pass; the incoming gradient is negated and scaled by alpha in the backward pass."""
|
|
|
@staticmethod |
|
def forward(ctx, x, alpha): |
|
ctx.alpha = alpha |
|
return x.view_as(x) |
|
|
|
@staticmethod |
|
def backward(ctx, grad_output): |
|
return grad_output.neg() * ctx.alpha, None |
|
|
|
|
|
class DomainClassifier(nn.Module): |
|
    '''A single-layer domain classifier with a gradient reversal layer.'''
|
def __init__(self, domain_nums=4, feature_dims=768): |
|
super().__init__() |
|
self.domain_nums = domain_nums |
|
self.feature_dims = feature_dims |
|
self.fc = nn.Linear(feature_dims, domain_nums) |
|
|
|
def forward(self, x): |
|
reverse_x = ReverseLayerF.apply(x, 1.0) |
|
return self.fc(reverse_x) |
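
# Hedged usage note (not in the original source): DomainClassifier is a single
# linear head fed through ReverseLayerF, i.e. a DANN-style adversarial domain
# discriminator. Given pooled 768-d features of shape (batch, 768) it returns
# domain logits of shape (batch, domain_nums), and the reversed gradient pushes
# the upstream encoder toward domain-invariant features.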
|
|
|
class CLP_clinical(nn.Module): |
|
def __init__(self, |
|
bert_model_name: str, |
|
embed_dim: int = 768, |
|
                 freeze_layers: Optional[Tuple[int, ...]] = None):
|
super().__init__() |
|
self.bert_model = self._get_bert_basemodel(bert_model_name=bert_model_name, freeze_layers=freeze_layers) |
|
self.mlp_embed = nn.Sequential( |
|
nn.Linear(embed_dim, embed_dim), |
|
nn.GELU(), |
|
nn.Linear(embed_dim, embed_dim) |
|
) |
|
self.embed_dim = embed_dim |
|
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) |
|
self.init_parameters() |
|
|
|
def init_parameters(self): |
|
nn.init.constant_(self.logit_scale, np.log(1 / 0.07)) |
|
for m in self.mlp_embed: |
|
if isinstance(m, nn.Linear): |
|
nn.init.normal_(m.weight, std=self.embed_dim ** -0.5) |
|
|
|
def _get_bert_basemodel(self, bert_model_name, freeze_layers=None): |
|
try: |
|
print(bert_model_name) |
|
config = BertConfig.from_pretrained(bert_model_name, output_hidden_states=True) |
|
model = AutoModel.from_pretrained(bert_model_name, config=config) |
|
print("Text feature extractor:", bert_model_name) |
|
print("bert encoder layers:",len(model.encoder.layer)) |
|
        except Exception as exc:
            raise ValueError("Invalid model name. Check the config file and pass a BERT model from the transformers library") from exc
|
|
|
if freeze_layers is not None: |
|
for layer_idx in freeze_layers: |
|
for param in list(model.encoder.layer[layer_idx].parameters()): |
|
param.requires_grad = False |
|
return model |
|
|
|
    def encode_text(self, text):
        # `text` is a tokenizer output dict with 'input_ids' and 'attention_mask'.
        output = self.bert_model(input_ids=text['input_ids'], attention_mask=text['attention_mask'])
        last_hidden_state, pooler_output, hidden_states = output[0], output[1], output[2]
        encode_out = self.mlp_embed(pooler_output)
        return encode_out
|
|
|
    def forward(self, text):
        return self.encode_text(text)
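
# Hedged usage sketch for CLP_clinical (not in the original source): tokenize the
# label/class prompts with the matching AutoTokenizer and pass the dict straight to
# encode_text; this assumes a BERT-style checkpoint that exposes a pooler output.
#   tokenizer = AutoTokenizer.from_pretrained(bert_model_name)
#   text = tokenizer(["pneumonia", "atelectasis"], padding=True, return_tensors="pt")
#   query_embeddings = text_encoder.encode_text(text)   # -> (num_prompts, 768)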
|
|
|
class ModelRes(nn.Module): |
|
def __init__(self, res_base_model): |
|
super(ModelRes, self).__init__() |
|
self.resnet_dict = { |
|
"resnet50": models.resnet50(pretrained=True), |
|
"resnet101": models.resnet101(pretrained=True), |
|
"resnet152": models.resnet152(pretrained=True), |
|
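            # The *_openai entries are placeholders: a CLIP-style ResNet visual
            # backbone (one with an attnpool head) must be loaded into this dict
            # before selecting one of these keys.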
"resnet50_openai": None, |
|
'resnet101_openai': None, |
|
'resnet50x4_openai': None, |
|
} |
|
|
|
self.resnet = self._get_res_basemodel(res_base_model) |
|
|
|
|
|
        if 'openai' in res_base_model:
            # CLIP-style ResNet: the attention-pooling head exposes the feature width.
            if self.resnet is None:
                raise ValueError(f"'{res_base_model}' is only a placeholder; load the CLIP visual backbone into resnet_dict first.")
            num_ftrs = int(self.resnet.attnpool.v_proj.in_features)
            self.res_features = nn.Sequential(*list(self.resnet.children())[:-1])
|
else: |
|
num_ftrs = int(self.resnet.fc.in_features) |
|
self.res_features = nn.Sequential(*list(self.resnet.children())[:-2]) |
|
|
|
self.res_l1 = nn.Linear(num_ftrs, num_ftrs) |
|
self.res_l2 = nn.Linear(num_ftrs, 768) |
|
|
|
def _get_res_basemodel(self, res_model_name): |
|
try: |
|
res_model = self.resnet_dict[res_model_name] |
|
print("Image feature extractor:", res_model_name) |
|
return res_model |
|
        except KeyError as exc:
            raise ValueError("Invalid model name. Check the config file and pass one of: resnet50, resnet101, resnet152 or a *_openai variant") from exc
|
|
|
def forward(self, img): |
|
|
|
batch_size = img.shape[0] |
|
res_fea = self.res_features(img) |
|
|
|
|
|
res_fea = rearrange(res_fea,'b d n1 n2 -> b (n1 n2) d') |
|
h = rearrange(res_fea,'b n d -> (b n) d') |
|
x = self.res_l1(h) |
|
x = F.relu(x) |
|
x = self.res_l2(x) |
|
out_emb = rearrange(x,'(b n) d -> b n d',b=batch_size) |
|
out_pool = torch.mean(out_emb,dim=1) |
|
return out_emb,out_pool |
|
|
|
|
|
class ModelConvNeXt(nn.Module): |
|
def __init__(self, convnext_base_model): |
|
super(ModelConvNeXt, self).__init__() |
|
self.convnext_dict = {"convnext-tiny": timm.create_model('convnextv2_tiny.fcmae_ft_in22k_in1k_384', pretrained=True, num_classes=1000), |
|
"convnext-base": timm.create_model('convnextv2_base.fcmae_ft_in22k_in1k_384', pretrained=True, num_classes=1000), |
|
} |
|
convnext = self._get_convnext_basemodel(convnext_base_model) |
|
num_ftrs = int(convnext.head.in_features) |
|
self.conv_features = nn.Sequential(*list(convnext.children())[:-2]) |
|
self.conv_l1 = nn.Linear(num_ftrs, num_ftrs) |
|
self.conv_l2 = nn.Linear(num_ftrs, 768) |
|
|
|
|
|
def _get_convnext_basemodel(self, convnext_model_name): |
|
try: |
|
convnext_model = self.convnext_dict[convnext_model_name] |
|
print("Image feature extractor:", convnext_model_name) |
|
return convnext_model |
|
        except KeyError as exc:
            raise ValueError("Invalid model name. Check the config file and pass one of: convnext-tiny or convnext-base") from exc
|
|
|
def forward(self, img): |
|
|
|
batch_size = img.shape[0] |
|
conv_fea = self.conv_features(img) |
|
conv_fea = F.adaptive_avg_pool2d(conv_fea, (1, 1)) |
|
conv_fea = rearrange(conv_fea,'b d n1 n2 -> b (n1 n2) d') |
|
h = rearrange(conv_fea,'b n d -> (b n) d') |
|
x = self.conv_l1(h) |
|
x = F.relu(x) |
|
x = self.conv_l2(x) |
|
out_emb = rearrange(x,'(b n) d -> b n d',b=batch_size) |
|
out_pool = torch.mean(out_emb,dim=1) |
|
return out_emb,out_pool |
|
|
|
class ModelEfficientV2(nn.Module): |
|
def __init__(self, efficientv2_base_model): |
|
super(ModelEfficientV2, self).__init__() |
|
self.efficientv2_dict = {"efficientnet_v2_s": models.efficientnet_v2_s(weights='EfficientNet_V2_S_Weights.IMAGENET1K_V1'),} |
|
self.efficientv2_model = self._get_efficientv2_basemodel(efficientv2_base_model) |
|
num_ftrs = int(self.efficientv2_model.classifier[-1].in_features) |
|
self.efficientv2_features = nn.Sequential(*list(self.efficientv2_model.children())[:-2]) |
|
self.efficientv2_l1 = nn.Linear(num_ftrs, num_ftrs) |
|
self.efficientv2_l2 = nn.Linear(num_ftrs, 768) |
|
|
|
|
|
def _get_efficientv2_basemodel(self, efficientv2_model_name): |
|
try: |
|
efficientv2_model = self.efficientv2_dict[efficientv2_model_name] |
|
print("Image feature extractor:", efficientv2_model_name) |
|
return efficientv2_model |
|
        except KeyError as exc:
            raise ValueError("Invalid model name. Check the config file and pass one of: efficientnet_v2_s") from exc
|
|
|
def forward(self, img): |
|
batch_size = img.shape[0] |
|
efficientv2_fea = self.efficientv2_features(img) |
|
|
|
|
|
efficientv2_fea = rearrange(efficientv2_fea,'b d n1 n2 -> b (n1 n2) d') |
|
|
|
h = rearrange(efficientv2_fea,'b n d -> (b n) d') |
|
|
|
x = self.efficientv2_l1(h) |
|
x = F.relu(x) |
|
x = self.efficientv2_l2(x) |
|
|
|
out_emb = rearrange(x,'(b n) d -> b n d',b=batch_size) |
|
out_pool = torch.mean(out_emb,dim=1) |
|
return out_emb,out_pool |
|
|
|
|
|
|
|
class ModelDense(nn.Module): |
|
def __init__(self, dense_base_model): |
|
super(ModelDense, self).__init__() |
|
|
|
self.densenet_dict = {"densenet121": models.densenet121(weights='DenseNet121_Weights.IMAGENET1K_V1'), |
|
"densenet161": models.densenet161(weights='DenseNet161_Weights.IMAGENET1K_V1'), |
|
"densenet201": models.densenet201(weights='DenseNet201_Weights.IMAGENET1K_V1'),} |
|
self.densenet = self._get_dense_basemodel(dense_base_model) |
|
num_ftrs = int(self.densenet.classifier.in_features) |
|
self.dense_features = self.densenet.features |
|
self.dense_l1 = nn.Linear(num_ftrs, num_ftrs) |
|
self.dense_l2 = nn.Linear(num_ftrs, 768) |
|
|
|
def _get_dense_basemodel(self, dense_base_model): |
|
try: |
|
dense_model = self.densenet_dict[dense_base_model] |
|
print("Image feature extractor:", dense_base_model) |
|
return dense_model |
|
        except KeyError as exc:
            raise ValueError("Invalid model name. Check the config file and pass one of: densenet121, densenet161 or densenet201") from exc
|
|
|
def forward(self, img): |
|
batch_size = img.shape[0] |
|
dense_fea = self.dense_features(img) |
|
dense_fea = rearrange(dense_fea,'b d n1 n2 -> b (n1 n2) d') |
|
h = rearrange(dense_fea,'b n d -> (b n) d') |
|
x = self.dense_l1(h) |
|
x = F.relu(x) |
|
x = self.dense_l2(x) |
|
out_emb = rearrange(x,'(b n) d -> b n d',b=batch_size) |
|
out_pool = torch.mean(out_emb,dim=1) |
|
return out_emb,out_pool |
|
|
|
class TQN_Model(nn.Module): |
|
def __init__(self, |
|
embed_dim: int = 768, |
|
class_num: int = 1, |
|
lam: list = [1, 0] |
|
): |
|
super().__init__() |
|
self.d_model = embed_dim |
|
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) |
|
|
|
|
|
decoder_layerV1 = TransformerDecoderLayerV1(self.d_model, 4, 1024, |
|
0.1, 'relu', True, lam) |
|
self.decoder_norm = nn.LayerNorm(self.d_model) |
|
|
|
|
|
self.decoderV1 = TransformerDecoderV1(decoder_layerV1, 4, self.decoder_norm, |
|
return_intermediate=False) |
|
|
|
self.dropout_feas = nn.Dropout(0.1) |
|
|
|
|
|
self.mlp_head = nn.Sequential( |
|
nn.Linear(embed_dim, class_num) |
|
) |
|
self.apply(self._init_weights) |
|
|
|
@staticmethod |
|
def _init_weights(module): |
|
if isinstance(module, nn.Linear): |
|
module.weight.data.normal_(mean=0.0, std=0.02) |
|
|
|
elif isinstance(module, nn.MultiheadAttention): |
|
module.in_proj_weight.data.normal_(mean=0.0, std=0.02) |
|
module.out_proj.weight.data.normal_(mean=0.0, std=0.02) |
|
|
|
elif isinstance(module, nn.Embedding): |
|
module.weight.data.normal_(mean=0.0, std=0.02) |
|
if module.padding_idx is not None: |
|
module.weight.data[module.padding_idx].zero_() |
|
|
|
def forward(self, image_features, text_features, return_atten = False): |
|
|
|
|
|
        # image_features: (batch, num_tokens, d_model); text_features: (num_queries, d_model)
        batch_size = image_features.shape[0]
        image_features = image_features.transpose(0, 1)                      # (num_tokens, batch, d)
        text_features = text_features.unsqueeze(1).repeat(1, batch_size, 1)  # (num_queries, batch, d)
        image_features = self.decoder_norm(image_features)
        text_features = self.decoder_norm(text_features)
|
|
|
image_features_pool = torch.mean(image_features,dim=0).unsqueeze(0) |
|
features,atten_map = self.decoderV1(text_features, image_features, image_features_pool, |
|
memory_key_padding_mask=None, pos=None, query_pos=None) |
|
features = self.dropout_feas(features).transpose(0,1) |
|
out = self.mlp_head(features) |
|
if return_atten: |
|
return out, atten_map |
|
else: |
|
return out |
|
|
|
|
|
|
|
class TQN_Model_Ensemble(nn.Module): |
|
def __init__(self, |
|
embed_dim: int = 768, |
|
class_num: int = 1, |
|
lam: list = [1, 0] |
|
): |
|
super().__init__() |
|
self.d_model = embed_dim |
|
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07)) |
|
decoder_layerV1 = TransformerDecoderLayerV1(self.d_model, 4, 1024, |
|
0.1, 'relu', True, lam) |
|
self.decoder_norm = nn.LayerNorm(self.d_model) |
|
self.decoder_norm_1 = nn.LayerNorm(self.d_model) |
|
self.decoder_norm_2 = nn.LayerNorm(self.d_model) |
|
self.decoderV1 = TransformerDecoderV1(decoder_layerV1, 4, self.decoder_norm, |
|
return_intermediate=False) |
|
self.decoderV1_1 = TransformerDecoderV1(decoder_layerV1, 4, self.decoder_norm_1, |
|
return_intermediate=False) |
|
self.decoderV1_2 = TransformerDecoderV1(decoder_layerV1, 4, self.decoder_norm_2, |
|
return_intermediate=False) |
|
|
|
self.dropout_feas = nn.Dropout(0.1) |
|
|
|
|
|
self.mlp_head = nn.Sequential(nn.Linear(embed_dim, class_num)) |
|
self.mlp_head_1 = nn.Sequential(nn.Linear(embed_dim, class_num)) |
|
self.mlp_head_2 = nn.Sequential(nn.Linear(embed_dim, class_num)) |
|
self.apply(self._init_weights) |
|
|
|
@staticmethod |
|
def _init_weights(module): |
|
if isinstance(module, nn.Linear): |
|
module.weight.data.normal_(mean=0.0, std=0.02) |
|
|
|
elif isinstance(module, nn.MultiheadAttention): |
|
module.in_proj_weight.data.normal_(mean=0.0, std=0.02) |
|
module.out_proj.weight.data.normal_(mean=0.0, std=0.02) |
|
|
|
elif isinstance(module, nn.Embedding): |
|
module.weight.data.normal_(mean=0.0, std=0.02) |
|
if module.padding_idx is not None: |
|
module.weight.data[module.padding_idx].zero_() |
|
|
|
def forward(self, image_features, text_features, return_atten = False): |
|
|
|
batch_size = image_features.shape[0] |
|
image_features = image_features.transpose(0,1) |
|
text_features = text_features.unsqueeze(1).repeat(1, batch_size, 1) |
|
image_features = self.decoder_norm(image_features) |
|
image_features_1 = self.decoder_norm_1(image_features) |
|
image_features_2 = self.decoder_norm_2(image_features) |
|
|
|
text_features = self.decoder_norm(text_features) |
|
text_features_1 = self.decoder_norm_1(text_features) |
|
text_features_2 = self.decoder_norm_2(text_features) |
|
|
|
image_features_pool = torch.mean(image_features,dim=0).unsqueeze(0) |
|
image_features_pool_1 = torch.mean(image_features_1,dim=0).unsqueeze(0) |
|
image_features_pool_2 = torch.mean(image_features_2,dim=0).unsqueeze(0) |
|
|
|
|
|
features,atten_map = self.decoderV1(text_features, image_features, image_features_pool, |
|
memory_key_padding_mask=None, pos=None, query_pos=None) |
|
features = self.dropout_feas(features).transpose(0,1) |
|
out = self.mlp_head(features) |
|
|
|
features_1,atten_map_1 = self.decoderV1_1(text_features_1, image_features_1, image_features_pool_1, |
|
memory_key_padding_mask=None, pos=None, query_pos=None) |
|
features_1 = self.dropout_feas(features_1).transpose(0,1) |
|
out_1 = self.mlp_head_1(features_1) |
|
|
|
features_2,atten_map_2 = self.decoderV1_2(text_features_2, image_features_2, image_features_pool_2, |
|
memory_key_padding_mask=None, pos=None, query_pos=None) |
|
features_2 = self.dropout_feas(features_2).transpose(0,1) |
|
out_2 = self.mlp_head_2(features_2) |
|
|
|
|
|
out_stack = torch.stack([out, out_1, out_2]) |
|
out = torch.mean(out_stack, dim=0) |
|
if return_atten: |
|
return out, atten_map |
|
else: |
|
return out |
|
|
if __name__ == "__main__": |
img = torch.randn(2,3,224,224) |
|
|
|
|
|
    model = ModelRes(res_base_model='resnet50')  # 'resnet50_openai' is only a placeholder (None) in resnet_dict
|
out_emb, out_pool = model(img) |
|
|
|
print(out_emb.size(), out_pool.size()) |
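
    # Hedged sketch (not in the original): exercise TQN_Model on the ResNet tokens
    # above. This assumes TransformerDecoderLayerV1/TransformerDecoderV1 from
    # models.transformer_decoder accept the (tgt, memory, memory_pool) call made in
    # TQN_Model.forward; the query embeddings are random stand-ins for the text
    # embeddings normally produced by CLP_clinical.
    num_classes = 5
    tqn = TQN_Model(embed_dim=768, class_num=1)
    query_embeddings = torch.randn(num_classes, 768)
    logits = tqn(out_emb, query_embeddings)
    print(logits.size())   # expected: torch.Size([2, 5, 1])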
|
|
|
|