# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import BaseTransformerLayer
from mmengine.model import BaseModule, ModuleList
from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict
from torch.nn import functional as F

from mmseg.registry import MODELS
from mmseg.utils import get_classes, get_predefined_templates, tokenizer


@MODELS.register_module()
class CLIPTextEncoder(BaseModule):
    """A text encoder with transformer architecture to encode the label text.

    Modified from https://github.com/MendelXu/SAN/blob/main/san/model/clip_utils/classifier.py # noqa:E501
    Copyright (c) 2023 MendelXu.
    Licensed under the MIT License

    Args:
        dataset_name (str|None): The name of the dataset to which
            the data belongs. Default: None.
        vocabulary (List[str]|None): The list of class names. Default: None.
        templates (List[str]|str): The prompt templates used for the labels,
            either an explicit list of templates or the name of a predefined
            template set. Default: 'vild'.
        total_vocab_size (int): Number of all words used by the pre-trained
            model. Default: 49408 (CLIP).
        context_length (int): The max length of prompt text.
            Default: 77 (CLIP).
        embed_dims (int): Width of transformer model. Default: 512.
        num_layers (int): Depth of transformer. Default: 12.
        num_heads (int): Number of attention heads in transformer.
            Default: 8.
        mlp_ratio (int): Ratio of mlp hidden dim to embedding dim in
            transformer. Default: 4.
        output_dims (int): Dim of output text embeddings. Default: 512.
        cache_feature (bool): Whether to save class embeddings in cache.
            Default: True.
        cat_bg (bool): Whether to add background embedding. Default: True.
        norm_cfg (dict|None): Config for norm layer. Default: dict(type='LN').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 dataset_name: str = None,
                 vocabulary: List[str] = None,
                 templates: str = 'vild',
                 total_vocab_size: int = 49408,
                 context_length: int = 77,
                 embed_dims: int = 512,
                 num_layers: int = 12,
                 num_heads: int = 8,
                 mlp_ratio: int = 4,
                 output_dims: int = 512,
                 cache_feature: bool = True,
                 cat_bg: bool = True,
                 norm_cfg: dict = dict(type='LN'),
                 init_cfg: dict = None):
        super().__init__(init_cfg)
        if isinstance(templates, List):
            self.templates = templates
        else:
            self.templates = get_predefined_templates(templates)

        assert dataset_name is not None or vocabulary is not None, \
            "text_encoder requires either 'dataset_name' or 'vocabulary'"
        assert dataset_name is None or vocabulary is None, \
            "there is a conflict between 'dataset_name' and 'vocabulary'"
        self.dataset_name = dataset_name
        self.vocabulary = vocabulary
        self.num_pos = context_length
        self.token_embedding = nn.Embedding(total_vocab_size, embed_dims)
        self.positional_embedding = nn.Parameter(
            torch.empty(context_length, embed_dims))
        self.text_projection = nn.Parameter(
            torch.empty(embed_dims, output_dims))
        # CLIP's learnable temperature, initialized to ln(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.transformer = ModuleList()
        self.register_buffer(
            'attn_mask', self.build_attention_mask(), persistent=False)
        for i in range(num_layers):
            self.transformer.append(
                BaseTransformerLayer(
                    attn_cfgs=dict(
                        type='MultiheadAttention',
                        embed_dims=embed_dims,
                        num_heads=num_heads,
                        batch_first=False,
                        bias=True),
                    ffn_cfgs=dict(
                        type='FFN',
                        embed_dims=embed_dims,
                        feedforward_channels=mlp_ratio * embed_dims,
                        act_cfg=dict(type='QuickGELU')),
                    operation_order=('norm', 'self_attn', 'norm', 'ffn')))
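
        # The ('norm', 'self_attn', 'norm', 'ffn') operation order above
        # builds pre-norm transformer blocks, the same block layout as
        # CLIP's text transformer.
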
        self.ln_final = build_norm_layer(
            norm_cfg, embed_dims, postfix='_final')[1]

        self.cache_feature = cache_feature
        if self.cache_feature:
            self.cache = {}

        self._freeze()

        # Created after _freeze() so the background embedding stays trainable.
        self.cat_bg = cat_bg
        if self.cat_bg:
            self.bg_embed = nn.Parameter(
                torch.randn(1, self.text_projection.shape[1]))

    def build_attention_mask(self):
        """Lazily create the causal attention mask, with full attention
        between the tokens.

        PyTorch uses an additive attention mask; fill with -inf.
        """
        mask = torch.empty(self.num_pos, self.num_pos)
        mask.fill_(float('-inf'))
        mask.triu_(1)  # zero out the lower diagonal
        return mask
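
    # For example, with num_pos = 3 the additive mask built above is
    #     [[0., -inf, -inf],
    #      [0.,   0., -inf],
    #      [0.,   0.,   0.]]
    # so every token can attend only to itself and to earlier positions.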

    def _freeze(self):
        """Freeze all parameters of the text encoder."""
        for param in self.parameters():
            param.requires_grad = False

    def init_weights(self):
        if self.cat_bg:
            nn.init.normal_(
                self.bg_embed,
                std=self.bg_embed.shape[1]**-0.5,
            )
        if isinstance(self.init_cfg, dict) and \
                self.init_cfg.get('type') == 'Pretrained_Part':
            checkpoint = CheckpointLoader.load_checkpoint(
                self.init_cfg['checkpoint'], logger=None, map_location='cpu')

            # Keep only the keys under the 'text_encoder.' prefix and strip
            # the prefix, so the sub-dict can be loaded into this module.
            state_dict = checkpoint.copy()
            para_prefix = 'text_encoder'
            prefix_len = len(para_prefix) + 1
            for k, v in checkpoint.items():
                state_dict.pop(k)
                if para_prefix in k:
                    state_dict[k[prefix_len:]] = v

            load_state_dict(self, state_dict, strict=False, logger=None)
        else:
            super().init_weights()

    def encode_text(self, text, normalize=False):
        """Encode class token."""
        embed_device = self.token_embedding.weight.device
        x = self.token_embedding(
            text.to(embed_device))  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        for block in self.transformer:
            x = block(query=x, attn_masks=self.attn_mask)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]

        # take features from the eot embedding
        # (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]),
              text.argmax(dim=-1)] @ self.text_projection
        return F.normalize(x, dim=-1) if normalize else x
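
    # In CLIP's BPE vocabulary the end-of-text token has the largest id
    # (49407), so ``text.argmax(dim=-1)`` above locates the EOT position in
    # each sequence; its feature is used as the embedding of the whole text.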

    def template_encode(self, vocabulary):
        """Prompt engineering."""
        text_embed_bucket = []
        for template in self.templates:
            text_inputs = tokenizer.tokenize(
                [template.format(noun) for noun in vocabulary])
            text_embed = self.encode_text(text_inputs, normalize=True)
            text_embed_bucket.append(text_embed)
        text_embed = torch.stack(text_embed_bucket).mean(dim=0)
        text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
        return text_embed
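
    # Each class name is formatted into every prompt template (e.g.
    # 'a photo of a {}.'), the per-template embeddings are averaged, and the
    # mean is re-normalized: the prompt-ensembling scheme used by CLIP/ViLD.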

    def forward(self):
        """Forward function."""
        if self.dataset_name is None:  # encoding vocabulary directly
            class_names = self.vocabulary
            if self.cache_feature:
                new_classes = [
                    word for word in class_names if word not in self.cache
                ]
                if len(new_classes) > 0:
                    class_embeds = self.template_encode(new_classes)
                    self.cache.update(dict(zip(new_classes, class_embeds)))
                class_embeds = torch.stack(
                    [self.cache[word] for word in class_names])
            else:
                class_embeds = self.template_encode(class_names)
        else:  # encoding the classes of the dataset
            class_names = get_classes(self.dataset_name)
            if class_names[0] == 'background':
                class_names = class_names[1:]
            if self.cache_feature:
                if self.dataset_name not in self.cache:
                    class_embeds = self.template_encode(class_names)
                    self.cache[self.dataset_name] = class_embeds
                else:
                    class_embeds = self.cache[self.dataset_name]
            else:
                class_embeds = self.template_encode(class_names)

        if self.cat_bg:
            class_embeds = torch.cat([class_embeds, self.bg_embed])
            class_embeds = F.normalize(class_embeds, p=2, dim=-1)
        return self.logit_scale.exp() * class_embeds
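
    # The output is a (num_classes, output_dims) matrix of L2-normalized
    # class embeddings (num_classes + 1 rows when cat_bg=True, the last row
    # being the background embedding), scaled by the CLIP temperature
    # exp(logit_scale), so its dot product with pixel or image embeddings
    # yields classification logits directly.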


@MODELS.register_module()
class QuickGELU(nn.Module):
    # From https://github.com/openai/CLIP/blob/main/clip/model.py
    # Registered so that the FFN config above can build it via
    # act_cfg=dict(type='QuickGELU').

    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)
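

# A minimal usage sketch: it runs the encoder with randomly initialized
# weights just to show the output shape. For meaningful embeddings, CLIP
# weights would be loaded via init_cfg=dict(type='Pretrained_Part',
# checkpoint='...') as handled in init_weights(); the vocabulary below is
# purely illustrative.
if __name__ == '__main__':
    from mmengine.registry import DefaultScope

    # Make registry lookups resolve in the 'mmseg' scope, where QuickGELU
    # was registered above.
    DefaultScope.get_instance('clip_text_encoder_demo', scope_name='mmseg')

    text_encoder = CLIPTextEncoder(
        vocabulary=['person', 'car', 'bicycle'], templates='vild')
    text_encoder.init_weights()
    with torch.no_grad():
        class_embeds = text_encoder()
    # 3 vocabulary classes + 1 background row, each of output_dims=512.
    print(class_embeds.shape)  # torch.Size([4, 512])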