import sys
# Add the parent folder's path to sys.path
sys.path.append(sys.path[0].replace('models', ''))
import re
import logging
import math
import json
import pathlib
import numpy as np
from copy import deepcopy
from pathlib import Path
from einops import rearrange
from collections import OrderedDict
from dataclasses import dataclass
from typing import Tuple, Union, Callable, Optional
import torch
from torch import nn
import torch.nn.functional as F
import torchvision.models as models
from torch.utils.checkpoint import checkpoint
from transformers import AutoModel,BertConfig,AutoTokenizer
from models.transformer_decoder import *
from torch.autograd import Function
import timm
class ReverseLayerF(Function):
    """Gradient reversal layer.

    The forward pass returns the input unchanged; the backward pass negates
    the incoming gradient and scales it by ``alpha``.
    """

    # torch.autograd.Function requires forward/backward to be static methods;
    # ``ReverseLayerF.apply`` supplies the ``ctx`` object itself.
    @staticmethod
    def forward(ctx, x, alpha):
        """Remember ``alpha`` for the backward pass and return a view of ``x``."""
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the reversed, ``alpha``-scaled gradient for ``x``.

        The second element is ``None`` because ``alpha`` receives no gradient.
        """
        return grad_output.neg() * ctx.alpha, None
class DomainClassifier(nn.Module):
    """A single-layer classifier preceded by a gradient reversal layer."""

    def __init__(self, domain_nums=4, feature_dims=768):
        """Create the classifier.

        Args:
            domain_nums: number of output classes of the linear layer.
            feature_dims: size of the last dimension of the input features.
        """
        # nn.Module must be initialised before sub-modules can be assigned.
        super().__init__()
        self.domain_nums = domain_nums
        self.feature_dims = feature_dims
        self.fc = nn.Linear(feature_dims, domain_nums)

    def forward(self, x):
        """Return the logits of ``x`` after gradient reversal (alpha = 1.0)."""
        reverse_x = ReverseLayerF.apply(x, 1.0)
        return self.fc(reverse_x)
class CLP_clinical(nn.Module):
    """Text encoder: a BERT backbone followed by a two-layer linear projection."""

    def __init__(self,
                 bert_model_name: str,
                 embed_dim: int = 768,
                 freeze_layers: Union[Tuple[int, int], int] = None):
        """Load the backbone and build the projection head.

        Args:
            bert_model_name: name or path handed to ``from_pretrained``.
            embed_dim: input and output width of the projection head.
            freeze_layers: iterable of encoder layer indices to freeze, or
                ``None`` to leave every layer trainable.
        """
        super().__init__()
        self.bert_model = self._get_bert_basemodel(bert_model_name=bert_model_name,
                                                   freeze_layers=freeze_layers)
        # NOTE(review): the source shows two back-to-back Linear layers with no
        # activation in between -- confirm that a non-linearity was not lost.
        self.mlp_embed = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.Linear(embed_dim, embed_dim),
        )
        self.embed_dim = embed_dim
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def init_parameters(self):
        """Reset ``logit_scale`` and re-draw the projection weights.

        Linear weights are sampled from a normal distribution with standard
        deviation ``embed_dim ** -0.5``; biases are left untouched.
        """
        nn.init.constant_(self.logit_scale, np.log(1 / 0.07))
        for m in self.mlp_embed:
            if isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, std=self.embed_dim ** -0.5)

    def _get_bert_basemodel(self, bert_model_name, freeze_layers=None):
        """Load the pretrained backbone and optionally freeze encoder layers.

        Raises:
            ValueError: if the model or its config cannot be loaded.
        """
        try:
            config = BertConfig.from_pretrained(bert_model_name, output_hidden_states=True)
            model = AutoModel.from_pretrained(bert_model_name, config=config)
        except (OSError, ValueError) as err:
            # A bare string cannot be raised; wrap the message in a real exception.
            raise ValueError("Invalid model name. Check the config file and pass a BERT model from transformers lybrary") from err
        print("Text feature extractor:", bert_model_name)
        print("bert encoder layers:", len(model.encoder.layer))

        if freeze_layers is not None:
            # NOTE(review): the annotation also allows a plain int, but the loop
            # below needs an iterable of layer indices -- confirm intended meaning.
            for layer_idx in freeze_layers:
                for param in model.encoder.layer[layer_idx].parameters():
                    param.requires_grad = False
        return model

    def encode_text(self, text):
        """Project the backbone's pooled output.

        Args:
            text: mapping with ``'input_ids'`` and ``'attention_mask'`` entries.

        Returns:
            The output of ``mlp_embed`` applied to element 1 of the backbone
            output (batch_size, dim according to the original comment).
        """
        output = self.bert_model(input_ids=text['input_ids'],
                                 attention_mask=text['attention_mask'])
        pooler_output = output[1]
        return self.mlp_embed(pooler_output)

    def forward(self, text):
        """Same computation as :meth:`encode_text`."""
        return self.encode_text(text)
class ModelRes(nn.Module):
    """ResNet image encoder that emits per-patch and pooled 768-d embeddings."""

    def __init__(self, res_base_model):
        """Build the selected backbone and the two projection layers.

        Args:
            res_base_model: key of ``self.resnet_dict``.

        Raises:
            ValueError: if the name is unknown or has no loader.
        """
        super(ModelRes, self).__init__()
        # Map names to constructors rather than to instances, so only the
        # requested backbone is built (and its weights downloaded).
        self.resnet_dict = {
            "resnet50": models.resnet50,
            "resnet101": models.resnet101,
            "resnet152": models.resnet152,
            # NOTE(review): no loader is visible for the *_openai variants.
            "resnet50_openai": None,
            'resnet101_openai': None,
            'resnet50x4_openai': None,
        }
        self.resnet = self._get_res_basemodel(res_base_model)
        if 'openai' in res_base_model:
            # Redefine res_features
            num_ftrs = int(self.resnet.attnpool.v_proj.in_features)
            self.res_features = nn.Sequential(*list(self.resnet.children())[:-1])
        else:
            num_ftrs = int(self.resnet.fc.in_features)
            # Drop the last two children so a spatial feature map is returned.
            self.res_features = nn.Sequential(*list(self.resnet.children())[:-2])
        # here num_ftrs = 2048
        self.res_l1 = nn.Linear(num_ftrs, num_ftrs)
        self.res_l2 = nn.Linear(num_ftrs, 768)

    def _get_res_basemodel(self, res_model_name):
        """Instantiate the pretrained backbone registered under ``res_model_name``.

        Raises:
            ValueError: if the name is not registered or has no loader.
        """
        builder = self.resnet_dict.get(res_model_name)
        if builder is None:
            valid = ", ".join(k for k, v in self.resnet_dict.items() if v is not None)
            raise ValueError(
                f"Invalid model name {res_model_name!r}. "
                f"Check the config file and pass one of: {valid}")
        res_model = builder(pretrained=True)
        print("Image feature extractor:", res_model_name)
        return res_model

    def forward(self, img):
        """Encode a batch of images.

        Returns:
            ``(out_emb, out_pool)``: ``out_emb`` is (batch, patch_num, 768) and
            ``out_pool`` is its mean over the patch axis.
        """
        #return (batchsize, patch_num, dim)
        batch_size = img.shape[0]
        res_fea = self.res_features(img)
        # Flatten the spatial grid into a sequence of patches.
        res_fea = rearrange(res_fea, 'b d n1 n2 -> b (n1 n2) d')
        h = rearrange(res_fea, 'b n d -> (b n) d')
        x = self.res_l1(h)
        x = F.relu(x)
        x = self.res_l2(x)
        out_emb = rearrange(x, '(b n) d -> b n d', b=batch_size)
        out_pool = torch.mean(out_emb, dim=1)
        return out_emb, out_pool
class ModelConvNeXt(nn.Module):
    """ConvNeXt-V2 image encoder that emits a 768-d embedding per image."""

    def __init__(self, convnext_base_model):
        """Build the selected timm backbone and the two projection layers.

        Args:
            convnext_base_model: key of ``self.convnext_dict``.

        Raises:
            ValueError: if the name is not registered.
        """
        super(ModelConvNeXt, self).__init__()
        # Map names to timm identifiers so only the requested model is created.
        self.convnext_dict = {
            "convnext-tiny": 'convnextv2_tiny.fcmae_ft_in22k_in1k_384',
            "convnext-base": 'convnextv2_base.fcmae_ft_in22k_in1k_384',
        }
        convnext = self._get_convnext_basemodel(convnext_base_model)
        num_ftrs = int(convnext.head.in_features)
        # Drop the last two children of the timm model; the remainder is the
        # feature extractor.
        self.conv_features = nn.Sequential(*list(convnext.children())[:-2])
        self.conv_l1 = nn.Linear(num_ftrs, num_ftrs)
        self.conv_l2 = nn.Linear(num_ftrs, 768)

    def _get_convnext_basemodel(self, convnext_model_name):
        """Create the pretrained timm model registered under the given name.

        Raises:
            ValueError: if the name is not registered.
        """
        timm_name = self.convnext_dict.get(convnext_model_name)
        if timm_name is None:
            valid = ", ".join(self.convnext_dict)
            raise ValueError(
                f"Invalid model name {convnext_model_name!r}. "
                f"Check the config file and pass one of: {valid}")
        convnext_model = timm.create_model(timm_name, pretrained=True, num_classes=1000)
        print("Image feature extractor:", convnext_model_name)
        return convnext_model

    def forward(self, img):
        """Encode a batch of images.

        Returns:
            ``(out_emb, out_pool)``: ``out_emb`` is (batch, 1, 768) because the
            feature map is average-pooled to 1x1 first; ``out_pool`` is its
            mean over the patch axis.
        """
        #return (batchsize, patch_num, dim)
        batch_size = img.shape[0]
        conv_fea = self.conv_features(img)
        conv_fea = F.adaptive_avg_pool2d(conv_fea, (1, 1))
        conv_fea = rearrange(conv_fea, 'b d n1 n2 -> b (n1 n2) d')
        h = rearrange(conv_fea, 'b n d -> (b n) d')
        x = self.conv_l1(h)
        x = F.relu(x)
        x = self.conv_l2(x)
        out_emb = rearrange(x, '(b n) d -> b n d', b=batch_size)
        out_pool = torch.mean(out_emb, dim=1)
        return out_emb, out_pool
class ModelEfficientV2(nn.Module):
    """EfficientNet-V2 image encoder emitting per-patch and pooled 768-d embeddings."""

    def __init__(self, efficientv2_base_model):
        """Build the selected backbone and the two projection layers.

        Args:
            efficientv2_base_model: key of ``self.efficientv2_dict``.

        Raises:
            ValueError: if the name is not registered.
        """
        super(ModelEfficientV2, self).__init__()
        # name -> (constructor, weights identifier); built on demand only.
        self.efficientv2_dict = {
            "efficientnet_v2_s": (models.efficientnet_v2_s,
                                  'EfficientNet_V2_S_Weights.IMAGENET1K_V1'),
        }
        self.efficientv2_model = self._get_efficientv2_basemodel(efficientv2_base_model)
        num_ftrs = int(self.efficientv2_model.classifier[-1].in_features)
        # Drop the last two children so a spatial feature map is returned.
        self.efficientv2_features = nn.Sequential(*list(self.efficientv2_model.children())[:-2])
        self.efficientv2_l1 = nn.Linear(num_ftrs, num_ftrs)
        self.efficientv2_l2 = nn.Linear(num_ftrs, 768)

    def _get_efficientv2_basemodel(self, efficientv2_model_name):
        """Instantiate the pretrained backbone registered under the given name.

        Raises:
            ValueError: if the name is not registered.
        """
        entry = self.efficientv2_dict.get(efficientv2_model_name)
        if entry is None:
            valid = ", ".join(self.efficientv2_dict)
            raise ValueError(
                f"Invalid model name {efficientv2_model_name!r}. "
                f"Check the config file and pass one of: {valid}")
        builder, weights = entry
        efficientv2_model = builder(weights=weights)
        print("Image feature extractor:", efficientv2_model_name)
        return efficientv2_model

    def forward(self, img):
        """Encode a batch of images.

        Returns:
            ``(out_emb, out_pool)``: ``out_emb`` is (batch, patch_num, 768) and
            ``out_pool`` is its mean over the patch axis.
        """
        batch_size = img.shape[0]
        efficientv2_fea = self.efficientv2_features(img)
        # Flatten the spatial grid into a sequence of patches.
        efficientv2_fea = rearrange(efficientv2_fea, 'b d n1 n2 -> b (n1 n2) d')
        h = rearrange(efficientv2_fea, 'b n d -> (b n) d')
        x = self.efficientv2_l1(h)
        x = F.relu(x)
        x = self.efficientv2_l2(x)
        out_emb = rearrange(x, '(b n) d -> b n d', b=batch_size)
        out_pool = torch.mean(out_emb, dim=1)
        return out_emb, out_pool
class ModelDense(nn.Module):
    """DenseNet image encoder emitting per-patch and pooled 768-d embeddings."""

    def __init__(self, dense_base_model):
        """Build the selected backbone and the two projection layers.

        Args:
            dense_base_model: key of ``self.densenet_dict``.

        Raises:
            ValueError: if the name is not registered.
        """
        super(ModelDense, self).__init__()
        # name -> (constructor, weights identifier); built on demand only.
        self.densenet_dict = {
            "densenet121": (models.densenet121, 'DenseNet121_Weights.IMAGENET1K_V1'),
            "densenet161": (models.densenet161, 'DenseNet161_Weights.IMAGENET1K_V1'),
            "densenet201": (models.densenet201, 'DenseNet201_Weights.IMAGENET1K_V1'),
        }
        self.densenet = self._get_dense_basemodel(dense_base_model)
        num_ftrs = int(self.densenet.classifier.in_features)
        self.dense_features = self.densenet.features
        self.dense_l1 = nn.Linear(num_ftrs, num_ftrs)
        self.dense_l2 = nn.Linear(num_ftrs, 768)

    def _get_dense_basemodel(self, dense_base_model):
        """Instantiate the pretrained backbone registered under the given name.

        Raises:
            ValueError: if the name is not registered.
        """
        entry = self.densenet_dict.get(dense_base_model)
        if entry is None:
            valid = ", ".join(self.densenet_dict)
            raise ValueError(
                f"Invalid model name {dense_base_model!r}. "
                f"Check the config file and pass one of: {valid}")
        builder, weights = entry
        dense_model = builder(weights=weights)
        print("Image feature extractor:", dense_base_model)
        return dense_model

    def forward(self, img):
        """Encode a batch of images.

        Returns:
            ``(out_emb, out_pool)``: ``out_emb`` is (batch, patch_num, 768) and
            ``out_pool`` is its mean over the patch axis.
        """
        batch_size = img.shape[0]
        dense_fea = self.dense_features(img)#N, 1024, 7,7
        # Flatten the spatial grid into a sequence of patches.
        dense_fea = rearrange(dense_fea, 'b d n1 n2 -> b (n1 n2) d')
        h = rearrange(dense_fea, 'b n d -> (b n) d')
        x = self.dense_l1(h)
        x = F.relu(x)
        x = self.dense_l2(x)
        out_emb = rearrange(x, '(b n) d -> b n d', b=batch_size)
        out_pool = torch.mean(out_emb, dim=1)
        return out_emb, out_pool
class TQN_Model(nn.Module):
    """Transformer query network: text queries attend to image patch features."""

    def __init__(self,
                 embed_dim: int = 768,
                 class_num: int = 1,
                 lam: list = [1, 0]
                 ):
        """Build the decoder stack and the classification head.

        Args:
            embed_dim: feature width of both image and text inputs.
            class_num: output width of the linear head.
            lam: forwarded to ``TransformerDecoderLayerV1``.
        """
        super().__init__()
        self.d_model = embed_dim
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        # Pass a copy so the shared default list can never be mutated downstream.
        decoder_layerV1 = TransformerDecoderLayerV1(self.d_model, 4, 1024,
                                                    0.1, 'relu', True, list(lam))
        self.decoder_norm = nn.LayerNorm(self.d_model)
        # NOTE(review): this call was cut off after the third argument in the
        # source; confirm no trailing keyword argument is required.
        self.decoderV1 = TransformerDecoderV1(decoder_layerV1, 4, self.decoder_norm)
        self.dropout_feas = nn.Dropout(0.1)
        self.mlp_head = nn.Sequential(  # nn.LayerNorm(768),
            nn.Linear(embed_dim, class_num)
        )
        # NOTE(review): nothing visible applies _init_weights; confirm whether
        # ``self.apply(self._init_weights)`` is expected here.

    @staticmethod
    def _init_weights(module):
        """Initialise weights from N(0.0, 0.02); zero an embedding's padding row.

        NOTE(review): the body was garbled in the source and is reconstructed
        from the surviving ``std=0.02`` / ``padding_idx`` fragments.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.MultiheadAttention):
            module.in_proj_weight.data.normal_(mean=0.0, std=0.02)
            module.out_proj.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def forward(self, image_features, text_features, return_atten = False):
        """Score every text query against every image in the batch.

        Args:
            image_features: (batch_size, patch_num, dim).
            text_features: (query_num, dim).
            return_atten: when true, also return the decoder's attention map.

        Returns:
            ``out``, or ``(out, atten_map)`` when ``return_atten`` is true.
            ``out`` is the head applied to the decoder output after swapping
            its first two axes.
        """
        batch_size = image_features.shape[0]
        # Sequence-first layout for the decoder: (patch_num, batch, dim).
        image_features = image_features.transpose(0, 1)
        # Repeat the shared queries for every batch item: (query_num, batch, dim).
        text_features = text_features.unsqueeze(1).repeat(1, batch_size, 1)
        image_features = self.decoder_norm(image_features)
        text_features = self.decoder_norm(text_features)

        image_features_pool = torch.mean(image_features, dim=0).unsqueeze(0)
        features, atten_map = self.decoderV1(text_features, image_features, image_features_pool,
                                             memory_key_padding_mask=None, pos=None, query_pos=None)
        features = self.dropout_feas(features).transpose(0, 1)
        out = self.mlp_head(features)
        if return_atten:
            return out, atten_map
        return out
class TQN_Model_Ensemble(nn.Module):
    """Three transformer query branches whose head outputs are averaged."""

    def __init__(self,
                 embed_dim: int = 768,
                 class_num: int = 1,
                 lam: list = [1, 0]
                 ):
        """Build three decoder/norm/head triples.

        Args:
            embed_dim: feature width of both image and text inputs.
            class_num: output width of every linear head.
            lam: forwarded to ``TransformerDecoderLayerV1``.
        """
        super().__init__()
        self.d_model = embed_dim
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        # Pass a copy so the shared default list can never be mutated downstream.
        decoder_layerV1 = TransformerDecoderLayerV1(self.d_model, 4, 1024,
                                                    0.1, 'relu', True, list(lam))
        self.decoder_norm = nn.LayerNorm(self.d_model)
        self.decoder_norm_1 = nn.LayerNorm(self.d_model)
        self.decoder_norm_2 = nn.LayerNorm(self.d_model)
        # NOTE(review): these calls were cut off after the third argument in the
        # source; confirm no trailing keyword argument is required. The same
        # layer object is passed to all three decoders -- whether they end up
        # sharing weights depends on TransformerDecoderV1; verify.
        self.decoderV1 = TransformerDecoderV1(decoder_layerV1, 4, self.decoder_norm)
        self.decoderV1_1 = TransformerDecoderV1(decoder_layerV1, 4, self.decoder_norm_1)
        self.decoderV1_2 = TransformerDecoderV1(decoder_layerV1, 4, self.decoder_norm_2)

        self.dropout_feas = nn.Dropout(0.1)

        self.mlp_head = nn.Sequential(nn.Linear(embed_dim, class_num))
        self.mlp_head_1 = nn.Sequential(nn.Linear(embed_dim, class_num))
        self.mlp_head_2 = nn.Sequential(nn.Linear(embed_dim, class_num))
        # NOTE(review): nothing visible applies _init_weights; confirm whether
        # ``self.apply(self._init_weights)`` is expected here.

    @staticmethod
    def _init_weights(module):
        """Initialise weights from N(0.0, 0.02); zero an embedding's padding row.

        NOTE(review): the body was garbled in the source and is reconstructed
        from the surviving ``std=0.02`` / ``padding_idx`` fragments.
        """
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.MultiheadAttention):
            module.in_proj_weight.data.normal_(mean=0.0, std=0.02)
            module.out_proj.weight.data.normal_(mean=0.0, std=0.02)
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def forward(self, image_features, text_features, return_atten = False):
        """Run the three branches and average their head outputs.

        Args:
            image_features: indexed as (batch_size, patch_num, dim).
            text_features: indexed as (query_num, dim).
            return_atten: when true, also return the attention map of the
                first branch only.

        Returns:
            ``out``, or ``(out, atten_map)`` when ``return_atten`` is true.
        """
        batch_size = image_features.shape[0]
        image_features = image_features.transpose(0, 1)
        text_features = text_features.unsqueeze(1).repeat(1, batch_size, 1)

        # Branches 1 and 2 normalise the output of ``decoder_norm``, not the
        # raw input, exactly as the original statement order did.
        # NOTE(review): confirm this double normalisation is intended.
        image_features = self.decoder_norm(image_features)
        image_inputs = [image_features,
                        self.decoder_norm_1(image_features),
                        self.decoder_norm_2(image_features)]
        text_features = self.decoder_norm(text_features)
        text_inputs = [text_features,
                       self.decoder_norm_1(text_features),
                       self.decoder_norm_2(text_features)]
        pooled_inputs = [torch.mean(feat, dim=0).unsqueeze(0) for feat in image_inputs]

        branches = [(self.decoderV1, self.mlp_head),
                    (self.decoderV1_1, self.mlp_head_1),
                    (self.decoderV1_2, self.mlp_head_2)]
        outputs = []
        atten_maps = []
        # Branches run in their original order (decoder, dropout, head).
        for (decoder, head), txt, img, pooled in zip(branches, text_inputs,
                                                     image_inputs, pooled_inputs):
            features, atten = decoder(txt, img, pooled,
                                      memory_key_padding_mask=None, pos=None, query_pos=None)
            features = self.dropout_feas(features).transpose(0, 1)
            outputs.append(head(features))
            atten_maps.append(atten)

        out_stack = torch.stack(outputs)
        out = torch.mean(out_stack, dim=0)
        if return_atten:
            return out, atten_maps[0]
        return out
if __name__ == "__main__":
    # Smoke test: push a random batch through the image encoder and print the
    # shapes of the per-patch and pooled embeddings.
    # 'resnet50' is used because the '*_openai' entries of ModelRes.resnet_dict
    # carry no model in this file.
    img = torch.randn(2, 3, 224, 224)
    model = ModelRes(res_base_model='resnet50')
    out_emb, out_pool = model(img)
    print(out_emb.size(), out_pool.size())