DarthReca committed
Commit 3c0c894 · verified · 1 Parent(s): 30f4336

Create modeling_closp.py

Files changed (1)
  1. modeling_closp.py +202 -0
modeling_closp.py ADDED
@@ -0,0 +1,202 @@
+ from dataclasses import dataclass
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from timm import create_model
+ from transformers import (
+     AutoConfig,
+     AutoModel,
+     AutoTokenizer,
+     PretrainedConfig,
+     PreTrainedModel,
+ )
+ from transformers.utils import ModelOutput
+
+ from .location_encoder import LocationEncoder
+
+
+ class CLOSPConfig(PretrainedConfig):
+     """
+     Configuration class for CLOSPModel.
+
+     This class stores the configuration of a CLOSPModel, which is used to instantiate the model
+     according to the specified parameters.
+     """
+
+     model_type = "closp"
+
+     def __init__(
+         self,
+         # Vision model parameters
+         vision_model_key: str = "vit-s",
+         s1_embedding_dim: int = 384,
+         s2_embedding_dim: int = 384,
+         s1_head_dim: int = 0,
+         s2_head_dim: int = 0,
+         # Text model parameters
+         text_model_name_or_path: str = "distilbert-base-uncased",
+         # Location encoder parameters (optional)
+         use_location_encoder: bool = True,
+         location_embedding_dim: int = 512,
+         # General model parameters
+         projection_dim: int = 768,
+         **kwargs,
+     ):
+         super().__init__(**kwargs)
+         self.vision_model_key = vision_model_key
+         self.s1_embedding_dim = s1_embedding_dim
+         self.s2_embedding_dim = s2_embedding_dim
+         self.text_model_name_or_path = text_model_name_or_path
+         self.use_location_encoder = use_location_encoder
+         self.location_embedding_dim = location_embedding_dim
+         self.projection_dim = projection_dim
+         self.s1_head_dim = s1_head_dim
+         self.s2_head_dim = s2_head_dim
+
+
+ # --- Structured Model Output ---
+ @dataclass
+ class CLOSPOutput(ModelOutput):
+     """
+     Base class for CLOSP model's outputs.
+     """
+
+     loss: torch.FloatTensor = None
+     logits_per_image: torch.FloatTensor = None
+     logits_per_text: torch.FloatTensor = None
+     logits_per_loc_img: torch.FloatTensor = None
+     logits_per_img_loc: torch.FloatTensor = None
+     image_embeds: torch.FloatTensor = None
+     text_embeds: torch.FloatTensor = None
+     location_embeds: torch.FloatTensor = None
+
+
+ class CLOSPModel(PreTrainedModel):
+     config_class = CLOSPConfig
+
+     def __init__(self, config: CLOSPConfig):
+         super().__init__(config)
+         # --- Vision Encoders ---
+         self.s1_encoder = create_model(
+             config.vision_model_key,
+             in_chans=2,
+             num_classes=config.s1_head_dim,
+             pretrained=False,
+         )
+         self.s2_encoder = create_model(
+             config.vision_model_key,
+             in_chans=13,
+             num_classes=config.s2_head_dim,
+             pretrained=False,
+         )
+         self.s1_projection = nn.Linear(config.s1_embedding_dim, config.projection_dim)
+         self.s2_projection = nn.Linear(config.s2_embedding_dim, config.projection_dim)
+
+         # --- Text Encoder ---
+         self.text_model = AutoModel.from_config(
+             AutoConfig.from_pretrained(config.text_model_name_or_path)
+         )
+         self.tokenizer = AutoTokenizer.from_pretrained(config.text_model_name_or_path)
+
+         # --- Location Encoder ---
+         if config.use_location_encoder:
+             self.location_encoder = LocationEncoder(512, 2, 256, 10)
+             self.location_projection = nn.Linear(
+                 config.location_embedding_dim, config.projection_dim
+             )
+
+     def tokenize_text(self, text: str):
+         """Tokenizes input text using the model's tokenizer."""
+         return self.tokenizer(
+             text,
+             padding="max_length",
+             truncation=True,
+             max_length=self.tokenizer.model_max_length,
+             return_tensors="pt",
+         )
+
+     def get_image_features(self, image: torch.Tensor) -> torch.Tensor:
+         """Encodes an image tensor into features."""
+         image = image.float()
+         if image.shape[1] == 2:  # Sentinel-1
+             image_features = self.s1_projection(self.s1_encoder(image))
+         else:  # Sentinel-2
+             image_features = self.s2_projection(self.s2_encoder(image))
+
+         return F.normalize(image_features, p=2, dim=-1)
+
+     def get_text_features(
+         self, input_ids: torch.Tensor, attention_mask: torch.Tensor
+     ) -> torch.Tensor:
+         """Encodes text tokens into features."""
+         text_outputs = self.text_model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             output_hidden_states=True,
+         )
+         text_features = text_outputs.last_hidden_state[:, 0, :]
+         return F.normalize(text_features, p=2, dim=-1)
+
+     def get_location_features(self, coords: torch.Tensor) -> torch.Tensor:
+         """Encodes coordinates into features."""
+         if not self.config.use_location_encoder:
+             raise ValueError(
+                 "Location encoder is not enabled for this model. Set `use_location_encoder=True` in config."
+             )
+         location_features = self.location_encoder(coords)
+         location_features = self.location_projection(location_features)
+         return F.normalize(location_features, p=2, dim=-1)
+
+     def forward(
+         self,
+         image: torch.Tensor,
+         input_ids: torch.Tensor,
+         attention_mask: torch.Tensor,
+         coords: torch.Tensor = None,
+         return_loss: bool = False,
+     ) -> CLOSPOutput:
+         image_embeds = self.get_image_features(image)
+         text_embeds = self.get_text_features(input_ids, attention_mask)
+
+         # Cosine similarity as logits
+         logits_per_image = image_embeds @ text_embeds.T
+         logits_per_text = logits_per_image.T
+
+         # --- Optional Location Logic ---
+         location_embeds = None
+         logits_per_loc_img = None
+         logits_per_img_loc = None
+
+         if self.config.use_location_encoder:
+             if coords is None:
+                 raise ValueError(
+                     "Coordinates must be provided when use_location_encoder is True."
+                 )
+             location_embeds = self.get_location_features(coords)
+             logits_per_loc_img = location_embeds @ image_embeds.T
+             logits_per_img_loc = image_embeds @ location_embeds.T
+
+         # --- Optional Loss Calculation ---
+         loss = None
+         if return_loss:
+             outputs = [
+                 logits_per_image,
+                 logits_per_text,
+                 logits_per_loc_img,
+                 logits_per_img_loc,
+             ]
+             ground_truth = torch.arange(len(input_ids)).to(self.device)
+             loss = [F.cross_entropy(o, ground_truth) for o in outputs if o is not None]
+             loss = sum(loss) / len(loss)
+
+         return CLOSPOutput(
+             loss=loss,
+             logits_per_image=logits_per_image,
+             logits_per_text=logits_per_text,
+             logits_per_loc_img=logits_per_loc_img,
+             logits_per_img_loc=logits_per_img_loc,
+             image_embeds=image_embeds,
+             text_embeds=text_embeds,
+             location_embeds=location_embeds,
+         )
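
For reference, a minimal usage sketch of the class added above (not part of this commit). It assumes the published repository exposes CLOSPModel through trust_remote_code (an auto_map entry in config.json), that the ViT-S encoders accept 224×224 inputs, and that coords are (latitude, longitude) pairs; the repository id is a placeholder.

import torch
from transformers import AutoModel

# Placeholder repo id; assumes config.json maps the "closp" model type to CLOSPModel.
model = AutoModel.from_pretrained("DarthReca/<closp-repo>", trust_remote_code=True).eval()

s2_patch = torch.randn(1, 13, 224, 224)   # assumed Sentinel-2 patch (13 bands)
tokens = model.tokenize_text("a flooded river near a city")
coords = torch.tensor([[45.07, 7.69]])    # assumed (latitude, longitude)

with torch.no_grad():
    out = model(
        image=s2_patch,
        input_ids=tokens["input_ids"],
        attention_mask=tokens["attention_mask"],
        coords=coords,
    )

print(out.logits_per_image)  # image-text cosine similarities (1 x 1)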