File size: 4,699 Bytes
9ca9564 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 |
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers import PreTrainedModel, VisionTextDualEncoderConfig, VisionTextDualEncoderModel
from transformers.models.vision_text_dual_encoder.modeling_vision_text_dual_encoder import clip_loss, CLIPOutput
class MeanPooler(nn.Module):
    """Mean-pool token hidden states, skipping positions masked out by the attention mask."""

    def forward(self, x, attention_mask=None):
        """Average the last hidden state over dimension 1.

        Args:
            x: Text-encoder output. Either an object exposing a
                ``last_hidden_state`` attribute, or an indexable output (such
                as the tuple produced with ``return_dict=False``) whose first
                element is the last hidden state.
            attention_mask: Optional mask; positions whose value is 0 do not
                contribute to the average. Presumably shaped
                ``(batch, seq_len)`` to match hidden states of shape
                ``(batch, seq_len, hidden)`` -- confirm against callers.
                When ``None``, every position is averaged.

        Returns:
            The hidden states with dimension 1 reduced by (masked) averaging.
        """
        # Accept both attribute-style outputs and plain tuples.
        hidden_states = getattr(x, "last_hidden_state", None)
        if hidden_states is None:
            hidden_states = x[0]

        # No mask supplied: every position counts equally.
        if attention_mask is None:
            return hidden_states.mean(dim=1)

        masked_output = hidden_states * attention_mask.unsqueeze(-1)
        # Clamp so a fully masked row yields zeros instead of 0/0 = NaN.
        token_counts = attention_mask.sum(-1, keepdim=True).clamp(min=1)
        return masked_output.sum(dim=1) / token_counts
class OpenCLIPVisionTextDualEncoderModel(VisionTextDualEncoderModel):
    """Vision-text dual encoder whose text branch uses mean pooling and a two-layer projection.

    Compared with the parent class, the constructor:
      * optionally disables the text model's own pooler,
      * adds a ``MeanPooler`` (``self.pooler``) applied to the text model output,
      * replaces ``self.text_projection`` with Linear -> GELU -> Linear (no biases).
    """

    def __init__(
        self,
        config: Optional[VisionTextDualEncoderConfig] = None,
        vision_model: Optional[PreTrainedModel] = None,
        text_model: Optional[PreTrainedModel] = None,
        add_text_model_pooling_layer: bool = False,
    ):
        """Build the dual encoder and override the text pooling / projection.

        Args:
            config: Forwarded to the parent constructor.
            vision_model: Forwarded to the parent constructor.
            text_model: Forwarded to the parent constructor.
            add_text_model_pooling_layer: When ``False`` (default), the text
                model's ``pooler`` attribute is set to ``None``.
        """
        super().__init__(config, vision_model, text_model)

        # Remove text pooling layer
        if not add_text_model_pooling_layer:
            self.text_model.pooler = None

        # Add mean pooling
        self.pooler = MeanPooler()

        # Overwrite text projection; the intermediate width is the midpoint
        # between the text embedding size and the shared projection size.
        hidden_size = (self.text_embed_dim + self.projection_dim) // 2
        self.text_projection = nn.Sequential(
            nn.Linear(self.text_embed_dim, hidden_size, bias=False),
            nn.GELU(),
            nn.Linear(hidden_size, self.projection_dim, bias=False),
        )

    def _pool_and_project_text(self, text_outputs, attention_mask):
        """Mean-pool the text model output and apply ``self.text_projection``.

        Normalises the inputs so ``self.pooler`` always receives an object
        with a ``last_hidden_state`` attribute and a concrete attention mask,
        even when the text model returned a tuple (``return_dict=False``) or
        the caller did not supply a mask.
        """
        from types import SimpleNamespace

        # Index 0 is taken to be the last hidden state for both tuple and
        # ModelOutput results -- TODO confirm for non-standard text models.
        hidden_states = text_outputs[0]

        if attention_mask is None:
            # No mask given: treat every position as valid. Assumes the first
            # two dims of the hidden states are (batch, seq_len).
            attention_mask = torch.ones(
                hidden_states.shape[:2], dtype=torch.long, device=hidden_states.device
            )

        if not hasattr(text_outputs, "last_hidden_state"):
            text_outputs = SimpleNamespace(last_hidden_state=hidden_states)

        pooled_output = self.pooler(text_outputs, attention_mask)
        return self.text_projection(pooled_output)

    def get_text_features(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        token_type_ids=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the text model and return the projected, mean-pooled text features.

        All arguments are forwarded to ``self.text_model``. The returned
        features are NOT normalised (normalisation only happens in ``forward``).
        """
        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        return self._pool_and_project_text(text_outputs, attention_mask)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        pixel_values: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        return_loss: Optional[bool] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple[torch.Tensor], CLIPOutput]:
        """Encode images and texts and compute their scaled cosine-similarity logits.

        Args:
            input_ids, attention_mask, position_ids, token_type_ids: Forwarded
                to ``self.text_model``; ``attention_mask`` is also used for
                mean pooling of the text hidden states.
            pixel_values: Forwarded to ``self.vision_model``.
            return_loss: When truthy, ``clip_loss`` is computed on the
                text-to-image logits.
            output_attentions, output_hidden_states: Forwarded to both encoders.
            return_dict: Falls back to ``self.config.return_dict`` when ``None``.

        Returns:
            A ``CLIPOutput`` when ``return_dict`` is truthy; otherwise the tuple
            ``(logits_per_image, logits_per_text, text_embeds, image_embeds,
            text_outputs, vision_outputs)``, prefixed with the loss when one
            was computed.
        """
        return_dict = return_dict if return_dict is not None else self.config.return_dict

        vision_outputs = self.vision_model(
            pixel_values=pixel_values,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        text_outputs = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        image_embeds = vision_outputs[1]  # pooler_output
        image_embeds = self.visual_projection(image_embeds)

        text_embeds = self._pool_and_project_text(text_outputs, attention_mask)

        # normalized features
        image_embeds = image_embeds / image_embeds.norm(dim=-1, keepdim=True)
        text_embeds = text_embeds / text_embeds.norm(dim=-1, keepdim=True)

        # cosine similarity as logits
        logit_scale = self.logit_scale.exp()
        logits_per_text = torch.matmul(text_embeds, image_embeds.t()) * logit_scale
        logits_per_image = logits_per_text.T

        loss = None
        if return_loss:
            loss = clip_loss(logits_per_text)

        if not return_dict:
            output = (logits_per_image, logits_per_text, text_embeds, image_embeds, text_outputs, vision_outputs)
            return ((loss,) + output) if loss is not None else output

        return CLIPOutput(
            loss=loss,
            logits_per_image=logits_per_image,
            logits_per_text=logits_per_text,
            text_embeds=text_embeds,
            image_embeds=image_embeds,
            text_model_output=text_outputs,
            vision_model_output=vision_outputs,
        )
|