import numpy as np
import torch
import torch.nn.functional as F
from transformers import (
    PretrainedConfig,
    PreTrainedModel,
    SiglipVisionConfig,
    SiglipVisionModel,
    XLMRobertaConfig,
    XLMRobertaModel,
)


class MexmaSigLIPConfig(PretrainedConfig):
    """Configuration for MexmaSigLIP.

    `optimized=True` switches the text tower to SDPA attention and the
    vision tower to FlashAttention-2.
    """

    def __init__(
        self,
        optimized: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.optimized = optimized


class MexmaSigLIP(PreTrainedModel):
    config_class = MexmaSigLIPConfig

    def __init__(self, config: MexmaSigLIPConfig):
        super().__init__(config)
        self.config = config
        # Text tower: MEXMA (XLM-RoBERTa) without the pooling head.
        text_config = XLMRobertaConfig.from_pretrained("facebook/MEXMA")
        if self.config.optimized:
            text_config._attn_implementation = "sdpa"
        self.text_model = XLMRobertaModel(text_config, add_pooling_layer=False)
        # Project the 1024-dim text embeddings into SigLIP's 1152-dim space.
        self.text_projector = torch.nn.Linear(1024, 1152, bias=False)
        # Vision tower: SigLIP SoViT-400m.
        vision_config = SiglipVisionConfig.from_pretrained(
            "google/siglip-so400m-patch14-384"
        )
        if self.config.optimized:
            vision_config._attn_implementation = "flash_attention_2"
        self.vision_model = SiglipVisionModel(vision_config).vision_model
        # SigLIP-style learnable temperature and bias for the pairwise logits.
        self.logit_scale = torch.nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
        self.logit_bias = torch.nn.Parameter(torch.ones([]) * -10)

    def forward(self, image_inputs, input_ids, attention_mask, normalize=False):
        text_features = self.encode_texts(input_ids, attention_mask, normalize)
        image_features = self.encode_images(image_inputs, normalize)
        return {
            "image_features": image_features,
            "text_features": text_features,
            "logit_scale": self.logit_scale,
            "logit_bias": self.logit_bias,
        }

    def encode_images(
        self,
        pixel_values,
        normalize=False,
    ):
        # pooler_output is SigLIP's attention-pooled image embedding.
        features = self.vision_model(pixel_values).pooler_output
        return F.normalize(features, dim=-1) if normalize else features

    def encode_texts(
        self,
        input_ids,
        attention_mask,
        normalize=False,
    ):
        # Use the first token's (<s>/CLS) hidden state as the sentence embedding.
        features = self.text_model(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state[:, 0]
        features = self.text_projector(features)
        return F.normalize(features, dim=-1) if normalize else features

    def get_logits(
        self,
        input_ids,
        attention_mask,
        pixel_values,
    ):
        image_features = self.encode_images(pixel_values, normalize=True)
        text_features = self.encode_texts(input_ids, attention_mask, normalize=True)
        # Scaled cosine similarities with an additive bias, as in the SigLIP loss.
        image_logits = (
            self.logit_scale.exp() * image_features @ text_features.T + self.logit_bias
        )
        text_logits = image_logits.T
        return image_logits, text_logits
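

# --- Usage sketch (illustrative, not part of the upstream API) --------------
# A minimal end-to-end example. It assumes the MEXMA tokenizer and the SigLIP
# image processor are the right preprocessing for the two towers; the model
# below is randomly initialized, so in practice you would load the released
# checkpoint with MexmaSigLIP.from_pretrained(...) instead.
if __name__ == "__main__":
    from transformers import AutoTokenizer, SiglipImageProcessor

    tokenizer = AutoTokenizer.from_pretrained("facebook/MEXMA")
    processor = SiglipImageProcessor.from_pretrained(
        "google/siglip-so400m-patch14-384"
    )
    model = MexmaSigLIP(MexmaSigLIPConfig()).eval()

    text_inputs = tokenizer(
        ["a photo of a cat", "a photo of a dog"],
        padding=True,
        return_tensors="pt",
    )
    # Stand-in image: random 384x384 RGB pixels instead of a real photo.
    image = np.random.randint(0, 256, (384, 384, 3), dtype=np.uint8)
    pixel_values = processor(images=image, return_tensors="pt").pixel_values

    with torch.no_grad():
        image_logits, text_logits = model.get_logits(
            input_ids=text_inputs.input_ids,
            attention_mask=text_inputs.attention_mask,
            pixel_values=pixel_values,
        )
    print(image_logits.shape)  # (num_images, num_texts)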