# TTP/mmseg/models/text_encoder/clip_text_encoder.py
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import BaseTransformerLayer
from mmengine.model import BaseModule, ModuleList
from mmengine.runner.checkpoint import CheckpointLoader, load_state_dict
from torch.nn import functional as F
from mmseg.registry import MODELS
from mmseg.utils import get_classes, get_predefined_templates, tokenizer
@MODELS.register_module()
class CLIPTextEncoder(BaseModule):
"""A text encoder with transformer architecture to encode the label text.
Modified from https://github.com/MendelXu/SAN/blob/main/san/model/clip_utils/classifier.py # noqa:E501
Copyright (c) 2023 MendelXu.
Licensed under the MIT License
Args:
        dataset_name (str|None): The name of the dataset to which
            the data belongs. Default: None.
        vocabulary (List[str]|None): The list of class names. Default: None.
        templates (List[str]|str): The prompt templates used for the labels,
            or the name of a predefined template set. Default: 'vild'.
        total_vocab_size (int): Number of all words used by the pre-trained
            model. Default: 49408 (CLIP).
        context_length (int): The max length of prompt text.
            Default: 77 (CLIP).
        embed_dims (int): Width of the transformer model. Default: 512.
        num_layers (int): Depth of the transformer. Default: 12.
        num_heads (int): Number of attention heads in the transformer.
            Default: 8.
        mlp_ratio (int): Ratio of the MLP hidden dim to the embedding dim in
            the transformer. Default: 4.
        output_dims (int): Dim of the output text embeddings. Default: 512.
        cache_feature (bool): Whether to save class embeddings in a cache.
            Default: True.
        cat_bg (bool): Whether to add a background embedding. Default: True.
        norm_cfg (dict|None): Config for the norm layer.
            Default: dict(type='LN').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
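
    Examples:
        >>> # A minimal usage sketch: the weights stay randomly initialised
        >>> # (no pretrained CLIP checkpoint is loaded), so only the shape of
        >>> # the returned class embeddings is meaningful here.
        >>> from mmengine.registry import init_default_scope
        >>> init_default_scope('mmseg')
        >>> text_encoder = CLIPTextEncoder(
        ...     vocabulary=['person', 'car'], templates='vild')
        >>> text_encoder.init_weights()
        >>> class_embeds = text_encoder()  # forward() takes no inputs
        >>> tuple(class_embeds.shape)  # 2 classes + 1 background embedding
        (3, 512)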
"""
def __init__(self,
dataset_name: str = None,
vocabulary: List[str] = None,
templates: str = 'vild',
total_vocab_size: int = 49408,
context_length: int = 77,
embed_dims: int = 512,
num_layers: int = 12,
num_heads: int = 8,
mlp_ratio: int = 4,
output_dims: int = 512,
cache_feature: bool = True,
cat_bg: bool = True,
norm_cfg: dict = dict(type='LN'),
init_cfg: dict = None):
super().__init__(init_cfg)
if isinstance(templates, List):
self.templates = templates
else:
self.templates = get_predefined_templates(templates)
        assert dataset_name is not None or vocabulary is not None, \
            "text_encoder requires either 'dataset_name' or 'vocabulary'"
        assert dataset_name is None or vocabulary is None, \
            "'dataset_name' and 'vocabulary' must not be set at the same time"
self.dataset_name = dataset_name
self.vocabulary = vocabulary
self.num_pos = context_length
self.token_embedding = nn.Embedding(total_vocab_size, embed_dims)
self.positional_embedding = nn.Parameter(
torch.empty(context_length, embed_dims))
self.text_projection = nn.Parameter(
torch.empty(embed_dims, output_dims))
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
self.transformer = ModuleList()
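        # The causal mask is registered with ``persistent=False`` so it is
        # excluded from the state dict; it can always be rebuilt from
        # ``context_length`` via ``build_attention_mask``.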
self.register_buffer(
'attn_mask', self.build_attention_mask(), persistent=False)
for i in range(num_layers):
self.transformer.append(
BaseTransformerLayer(
attn_cfgs=dict(
type='MultiheadAttention',
embed_dims=embed_dims,
num_heads=num_heads,
batch_first=False,
bias=True),
ffn_cfgs=dict(
type='FFN',
embed_dims=embed_dims,
feedforward_channels=mlp_ratio * embed_dims,
act_cfg=dict(type='QuickGELU')),
operation_order=('norm', 'self_attn', 'norm', 'ffn')))
self.ln_final = build_norm_layer(
norm_cfg, embed_dims, postfix='_final')[1]
self.cache_feature = cache_feature
if self.cache_feature:
self.cache = {}
self._freeze()
self.cat_bg = cat_bg
if self.cat_bg:
self.bg_embed = nn.Parameter(
torch.randn(1, self.text_projection.shape[1]))
    @property
    def ln_final(self):
        # ``__init__`` assigns the built norm layer to ``self.ln_final``,
        # which ``nn.Module`` stores in ``self._modules``; read it back from
        # there so this property does not shadow the registered sub-module.
        return self._modules['ln_final']
def build_attention_mask(self):
"""lazily create causal attention mask, with full attention between the
tokens.
pytorch uses additive attention mask; fill with -inf
"""
mask = torch.empty(self.num_pos, self.num_pos)
mask.fill_(float('-inf'))
        mask.triu_(1)  # zero out the diagonal and the lower triangle
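        # For example, with ``num_pos == 4`` the mask built above is
        #   [[0., -inf, -inf, -inf],
        #    [0.,   0., -inf, -inf],
        #    [0.,   0.,   0., -inf],
        #    [0.,   0.,   0.,   0.]]
        # so each token attends only to itself and to earlier positions.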
return mask
def _freeze(self):
for param in self.parameters():
param.requires_grad = False
def init_weights(self):
if self.cat_bg:
nn.init.normal_(
self.bg_embed,
std=self.bg_embed.shape[1]**-0.5,
)
if isinstance(self.init_cfg, dict) and \
self.init_cfg.get('type') == 'Pretrained_Part':
checkpoint = CheckpointLoader.load_checkpoint(
self.init_cfg['checkpoint'], logger=None, map_location='cpu')
state_dict = checkpoint.copy()
para_prefix = 'text_encoder'
prefix_len = len(para_prefix) + 1
for k, v in checkpoint.items():
state_dict.pop(k)
if para_prefix in k:
state_dict[k[prefix_len:]] = v
load_state_dict(self, state_dict, strict=False, logger=None)
else:
super().init_weights()
@torch.no_grad()
def encode_text(self, text, normalize=False):
"""encode class token."""
embed_device = self.token_embedding.weight.device
x = self.token_embedding(
text.to(embed_device)) # [batch_size, n_ctx, d_model]
x = x + self.positional_embedding
x = x.permute(1, 0, 2) # NLD -> LND
for block in self.transformer:
x = block(query=x, attn_masks=self.attn_mask)
x = x.permute(1, 0, 2) # LND -> NLD
x = self.ln_final(x) # [batch_size, n_ctx, transformer.width]
# take features from the eot embedding
# (eot_token is the highest number in each sequence)
x = x[torch.arange(x.shape[0]),
text.argmax(dim=-1)] @ self.text_projection
return F.normalize(x, dim=-1) if normalize else x
def template_encode(self, vocabulary):
"""Prompt engineering."""
text_embed_bucket = []
for template in self.templates:
text_inputs = tokenizer.tokenize(
[template.format(noun) for noun in vocabulary])
text_embed = self.encode_text(text_inputs, normalize=True)
text_embed_bucket.append(text_embed)
text_embed = torch.stack(text_embed_bucket).mean(dim=0)
text_embed = text_embed / text_embed.norm(dim=-1, keepdim=True)
return text_embed
def forward(self):
"""Forward function."""
if self.dataset_name is None: # encoding vocabulary directly
class_names = self.vocabulary
if self.cache_feature:
new_classes = [
word for word in class_names if word not in self.cache
]
if len(new_classes) > 0:
class_embeds = self.template_encode(new_classes)
self.cache.update(dict(zip(new_classes, class_embeds)))
class_embeds = torch.stack(
[self.cache[word] for word in class_names])
else:
class_embeds = self.template_encode(class_names)
else: # encoding the classes of the dataset
class_names = get_classes(self.dataset_name)
if class_names[0] == 'background':
class_names = class_names[1:]
if self.cache_feature:
if self.dataset_name not in self.cache:
class_embeds = self.template_encode(class_names)
self.cache[self.dataset_name] = class_embeds
else:
class_embeds = self.cache[self.dataset_name]
else:
class_embeds = self.template_encode(class_names)
if self.cat_bg:
class_embeds = torch.cat([class_embeds, self.bg_embed])
class_embeds = F.normalize(class_embeds, p=2, dim=-1)
return self.logit_scale.exp() * class_embeds
@MODELS.register_module()
class QuickGELU(nn.Module):
# From https://github.com/openai/CLIP/blob/main/clip/model.py
def forward(self, x: torch.Tensor):
return x * torch.sigmoid(1.702 * x)
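

# A minimal smoke-test sketch: it assumes mmseg is installed together with its
# bundled CLIP BPE vocabulary, keeps the randomly initialised weights (no
# pretrained CLIP checkpoint is loaded), and therefore only checks output
# shapes rather than embedding quality.
if __name__ == '__main__':
    from mmengine.registry import init_default_scope

    # Set the default scope so components registered under 'mmseg'
    # (e.g. QuickGELU above) can be resolved when the FFN blocks are built.
    init_default_scope('mmseg')

    text_encoder = CLIPTextEncoder(dataset_name='cityscapes')
    text_encoder.init_weights()
    with torch.no_grad():
        class_embeds = text_encoder()
    # 19 Cityscapes classes plus the learnable background embedding.
    print(class_embeds.shape)  # expected: torch.Size([20, 512])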