Create analogy_encoder.py
Browse files
analogy_encoder/analogy_encoder.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
ADOBE CONFIDENTIAL
|
| 3 |
+
Copyright 2024 Adobe
|
| 4 |
+
All Rights Reserved.
|
| 5 |
+
NOTICE: All information contained herein is, and remains
|
| 6 |
+
the property of Adobe and its suppliers, if any. The intellectual
|
| 7 |
+
and technical concepts contained herein are proprietary to Adobe
|
| 8 |
+
and its suppliers and are protected by all applicable intellectual
|
| 9 |
+
property laws, including trade secret and copyright laws.
|
| 10 |
+
Dissemination of this information or reproduction of this material
|
| 11 |
+
is strictly forbidden unless prior written permission is obtained
|
| 12 |
+
from Adobe.
|
| 13 |
+
"""
|
| 14 |
+
|
| 15 |
+
import torch as th
|
| 16 |
+
from diffusers import ModelMixin
|
| 17 |
+
from transformers import AutoModel, SiglipVisionConfig, Dinov2Config
|
| 18 |
+
from transformers import SiglipVisionModel
|
| 19 |
+
|
| 20 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 21 |
+
|
| 22 |
+
class AnalogyEncoder(ModelMixin, ConfigMixin):
    """Frozen dual image encoder (DINOv2 + SigLIP) for analogy conditioning.

    Wraps two vision backbones, freezes them (``requires_grad_(False)``) and
    converts them to ``channels_last`` memory format. ``forward`` returns one
    embedding per backbone, each the channel-wise concatenation of the
    norm-scaled final hidden states and the raw first (embedding-layer)
    hidden states.
    """

    @register_to_config
    def __init__(self, load_pretrained=False,
                 dino_config_dict=None, siglip_config_dict=None):
        """Build the two frozen encoders.

        Args:
            load_pretrained: If True, download pretrained fp16 weights for
                ``facebook/dinov2-large`` and
                ``google/siglip-large-patch16-256`` (SigLIP with SDPA
                attention). If False, build randomly initialized models from
                the supplied config dicts.
            dino_config_dict: ``Dinov2Config`` as a plain dict. Required when
                ``load_pretrained`` is False.
            siglip_config_dict: ``SiglipVisionConfig`` as a plain dict.
                Required when ``load_pretrained`` is False.

        Raises:
            ValueError: if ``load_pretrained`` is False and either config
                dict is missing.
        """
        super().__init__()
        if load_pretrained:
            image_encoder_dino = AutoModel.from_pretrained(
                'facebook/dinov2-large', torch_dtype=th.float16)
            image_encoder_siglip = SiglipVisionModel.from_pretrained(
                "google/siglip-large-patch16-256",
                torch_dtype=th.float16, attn_implementation="sdpa")
        else:
            # Fail fast with a clear message instead of the opaque error
            # `from_dict(None)` raises deep inside transformers.
            if dino_config_dict is None or siglip_config_dict is None:
                raise ValueError(
                    "dino_config_dict and siglip_config_dict must be "
                    "provided when load_pretrained is False")
            image_encoder_dino = AutoModel.from_config(
                Dinov2Config.from_dict(dino_config_dict))
            image_encoder_siglip = AutoModel.from_config(
                SiglipVisionConfig.from_dict(siglip_config_dict))

        # Both backbones are frozen feature extractors; channels_last is a
        # memory-layout optimization only and does not change results.
        image_encoder_dino.requires_grad_(False)
        image_encoder_dino = image_encoder_dino.to(
            memory_format=th.channels_last)

        image_encoder_siglip.requires_grad_(False)
        image_encoder_siglip = image_encoder_siglip.to(
            memory_format=th.channels_last)

        self.image_encoder_dino = image_encoder_dino
        self.image_encoder_siglip = image_encoder_siglip

    def dino_normalization(self, encoder_output):
        """Scale all DINO tokens by the L2 norm of the leading (CLS) token.

        Args:
            encoder_output: model output whose ``last_hidden_state`` has
                shape (batch, tokens, dim); token 0 is treated as the
                pooled/CLS token.

        Returns:
            Tensor of the same shape, divided by the token-0 norm
            (broadcast over all tokens).
        """
        embeds = encoder_output.last_hidden_state
        embeds_pooled = embeds[:, 0:1]
        embeds = embeds / th.norm(embeds_pooled, dim=-1, keepdim=True)
        return embeds

    def siglip_normalization(self, encoder_output):
        """Prepend the pooled SigLIP token, then scale by its L2 norm.

        Args:
            encoder_output: model output providing ``pooler_output``
                (batch, dim) and ``last_hidden_state`` (batch, tokens, dim).

        Returns:
            Tensor of shape (batch, tokens + 1, dim): pooled token first,
            then patch tokens, all divided by the pooled token's norm.
        """
        embeds = th.cat([encoder_output.pooler_output[:, None, :],
                         encoder_output.last_hidden_state], dim=1)
        embeds_pooled = embeds[:, 0:1]
        embeds = embeds / th.norm(embeds_pooled, dim=-1, keepdim=True)
        return embeds

    def forward(self, dino_in, siglip_in):
        """Encode a pair of preprocessed images.

        Args:
            dino_in: pixel tensor for the DINOv2 encoder.
            siglip_in: pixel tensor for the SigLIP encoder.

        Returns:
            (dino_embd, siglip_embd): per-backbone embeddings, each the
            channel-wise concat of normalized last hidden states and the
            first (embedding-layer) hidden states.
        """
        x_1 = self.image_encoder_dino(dino_in, output_hidden_states=True)
        x_1_first = x_1.hidden_states[0]
        x_1 = self.dino_normalization(x_1)

        x_2 = self.image_encoder_siglip(siglip_in, output_hidden_states=True)
        x_2_first = x_2.hidden_states[0]
        # siglip_normalization prepends the pooler token, so x_2 carries
        # tokens + 1 entries. Mean-pool the first-layer patch embeddings and
        # prepend that as a synthetic "pooled" token so x_2_first lines up
        # token-for-token with x_2 before the channel concat below.
        x_2_first_pool = th.mean(x_2_first, dim=1, keepdim=True)
        x_2_first = th.cat([x_2_first_pool, x_2_first], 1)
        x_2 = self.siglip_normalization(x_2)

        dino_embd = th.cat([x_1, x_1_first], -1)
        siglip_embd = th.cat([x_2, x_2_first], -1)
        return dino_embd, siglip_embd