Commit cdebfc7
1 Parent(s): ae96581

fix: various fixes

Files changed:
- configuration_clip.py +4 -0
- eva_model.py +2 -1
- hf_model.py +67 -24
- modeling_clip.py +20 -27
- rope_embeddings.py +4 -9
configuration_clip.py CHANGED

```diff
@@ -24,6 +24,8 @@ class JinaCLIPTextConfig(PretrainedConfig):
         embed_dim: int = 768,
         hf_model_name_or_path: str = 'jinaai/jina-bert-flash-implementation',
         hf_model_config_kwargs: Optional[Dict[str, Any]] = None,
+        default_instruction_task: Optional[str] = None,
+        default_lora_task: Optional[str] = None,
         pooler_type: Optional[str] = None,
         proj_type: Optional[str] = None,
         proj_bias: bool = False,
@@ -34,6 +36,8 @@ class JinaCLIPTextConfig(PretrainedConfig):
         self.embed_dim = embed_dim
         self.hf_model_name_or_path = hf_model_name_or_path
         self.hf_model_config_kwargs = hf_model_config_kwargs or {}
+        self.default_instruction_task = default_instruction_task
+        self.default_lora_task = default_lora_task
        self.pooler_type = pooler_type
         self.proj_type = proj_type
         self.proj_bias = proj_bias
```
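Both new fields default to `None`, which keeps the previous behavior (no instruction prefix, no LoRA adapter). A minimal sketch of setting them, assuming the repo's `configuration_clip` module is importable locally; the task name `'retrieval.query'` is purely illustrative and must match a task the underlying text model actually exposes:

```python
from configuration_clip import JinaCLIPTextConfig  # local module from this repo

text_config = JinaCLIPTextConfig(
    embed_dim=768,
    default_instruction_task='retrieval.query',  # illustrative task name
    default_lora_task='retrieval.query',         # illustrative task name
)
```

`_build_text_tower` in modeling_clip.py (below) forwards both fields into the text encoder.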
eva_model.py CHANGED

```diff
@@ -12,7 +12,8 @@ import torch.nn as nn
 import torch.nn.functional as f
 
 try:
-    from timm.models.layers import drop_path as timm_drop_path
+    from timm.models.layers import drop_path as timm_drop_path
+    from timm.models.layers import to_2tuple, trunc_normal_
 except ImportError or ModuleNotFoundError:
     from timm.layers import drop_path as timm_drop_path, to_2tuple, trunc_normal_
```
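One aside on the `except` clause this hunk leaves in place: `except ImportError or ModuleNotFoundError:` first evaluates `ImportError or ModuleNotFoundError` as a boolean expression, and since classes are truthy it reduces to plain `except ImportError:`. Behavior happens to be unchanged because `ModuleNotFoundError` is a subclass of `ImportError`, but the idiomatic spelling catches a tuple:

```python
try:
    from timm.models.layers import drop_path as timm_drop_path
    from timm.models.layers import to_2tuple, trunc_normal_
except (ImportError, ModuleNotFoundError):  # tuple form, not `or`
    # newer timm releases moved these helpers to timm.layers
    from timm.layers import drop_path as timm_drop_path, to_2tuple, trunc_normal_
```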
hf_model.py CHANGED

```diff
@@ -1,5 +1,7 @@
 import re
+import warnings
 from typing import Dict, Optional
+
 import torch
 import torch.nn as nn
 from transformers import AutoConfig, AutoModel, PretrainedConfig
@@ -9,7 +11,6 @@ from transformers.modeling_outputs import (
     BaseModelOutputWithPoolingAndCrossAttentions,
 )
 
-
 _HF_ARCH_DICT = {
     # https://huggingface.co/docs/transformers/model_doc/roberta#roberta
     'roberta': {
@@ -120,6 +121,8 @@ class HFTextEncoder(nn.Module):
         trust_remote_code: bool = False,
         revision: Optional[str] = None,
         code_revision: Optional[str] = None,
+        default_instruction_task: Optional[str] = None,
+        default_lora_task: Optional[str] = None,
         model_config_kwargs: Optional[Dict] = None,
     ):
         super().__init__()
@@ -129,39 +132,35 @@ class HFTextEncoder(nn.Module):
         model_config_kwargs = model_config_kwargs or {}
 
         if config is None:
-            self.config = AutoConfig.from_pretrained(
-                model_name_or_path,
-                trust_remote_code=trust_remote_code,
-                revision=revision,
-                code_revision=code_revision,
-            )
-            self.config.update(model_config_kwargs)
-            create_func, model_args = (
-                (AutoModel.from_pretrained, model_name_or_path)
-                if pretrained
-                else (AutoModel.from_config, self.config)
-            )
-            if (
-                hasattr(self.config, 'is_encoder_decoder')
-                and self.config.is_encoder_decoder
-            ):
-                self.transformer = create_func(
-                    model_args,
+            if pretrained:
+                self.transformer = AutoModel.from_pretrained(
+                    model_name_or_path,
                     trust_remote_code=trust_remote_code,
                     revision=revision,
+                    add_pooling_layer=False,
                     code_revision=code_revision,
                     **model_config_kwargs,
                 )
-                self.transformer = self.transformer.encoder
+                self.config = self.transformer.config
             else:
-                self.transformer = create_func(
-                    model_args,
+                self.config = AutoConfig.from_pretrained(
+                    model_name_or_path,
+                    trust_remote_code=trust_remote_code,
+                    code_revision=code_revision,
+                )
+                self.config.update(model_config_kwargs)
+                self.transformer = AutoModel.from_config(
+                    self.config,
                     trust_remote_code=trust_remote_code,
-                    revision=revision,
                     add_pooling_layer=False,
                     code_revision=code_revision,
-                    **model_config_kwargs,
                 )
+            if (
+                hasattr(self.config, 'is_encoder_decoder')
+                and self.config.is_encoder_decoder
+            ):
+                self.transformer = self.transformer.encoder
+
         else:
             self.config = config
             self.config.update(model_config_kwargs)
@@ -209,6 +208,50 @@ class HFTextEncoder(nn.Module):
             self._task_instructions = self.transformer._task_instructions
             self._supports_task_instructions = True
 
+        self.default_instruction_task = None
+        self.default_lora_task = None
+        self.default_instruction = None
+        self.default_loraid = None
+        if default_instruction_task is not None:
+            self.default_instruction_task = default_instruction_task
+            self.default_instruction = self.get_instruction_from_task(
+                default_instruction_task
+            )
+        if default_lora_task is not None:
+            self.default_lora_task = default_lora_task
+            self.default_loraid = self.get_loraid_from_task(default_lora_task)
+
+    def get_instruction_from_task(self, task: str) -> Optional[str]:
+        if self._supports_task_instructions:
+            if task not in self._task_instructions:
+                raise ValueError(
+                    f'Unsupported task \'{task}\'. Choose one of the following: '
+                    f'{", ".join(self._task_instructions)} or set to None to disable '
+                    f'task instructions completely'
+                )
+            return self._task_instructions[task]
+        else:
+            warnings.warn(
+                'Model does not support task instructions, ignoring instruction '
+                f"task '{task}'"
+            )
+            return None
+
+    def get_loraid_from_task(self, task: str) -> Optional[int]:
+        if self._supports_lora:
+            if task not in self._lora_adaptation_map:
+                raise ValueError(
+                    f'Unsupported task \'{task}\'. Choose one of the following: '
+                    f'{", ".join(self._task_instructions)} or set to None to disable '
+                    f'the LoRA adapters completely'
+                )
+            return self._lora_adaptation_map[task]
+        else:
+            warnings.warn(
+                f"Model does not support LoRA adapters, ignoring LoRA task '{task}'"
+            )
+            return None
+
     @torch.jit.ignore
     def set_grad_checkpointing(self, _=True):
         self.transformer.gradient_checkpointing_enable()
```
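Besides the constructor split (`AutoModel.from_pretrained` when loading weights, `AutoConfig` plus `AutoModel.from_config` otherwise), the two new helpers centralize the task validation that modeling_clip.py previously performed inline. A hedged usage sketch, assuming a checkpoint whose inner transformer exposes `_task_instructions` and `_lora_adaptation_map`; the task name is an assumption for illustration:

```python
encoder = HFTextEncoder(
    model_name_or_path='jinaai/jina-bert-flash-implementation',
    output_dim=768,
    default_instruction_task='retrieval.query',  # hypothetical task name
    trust_remote_code=True,
)

# Both helpers raise ValueError for an unknown task, and warn + return None
# when the underlying model supports neither mechanism.
instruction = encoder.get_instruction_from_task('retrieval.query')  # str or None
loraid = encoder.get_loraid_from_task('retrieval.query')            # int or None
```

Note that the `ValueError` message in `get_loraid_from_task` joins `self._task_instructions` rather than `self._lora_adaptation_map`; that appears to be a copy-paste slip carried in the commit itself, listing instruction tasks in a LoRA error.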
modeling_clip.py CHANGED

```diff
@@ -68,6 +68,8 @@ def _build_text_tower(config: JinaCLIPTextConfig) -> HFTextEncoder:
     return HFTextEncoder(
         model_name_or_path=config.hf_model_name_or_path,
         output_dim=config.embed_dim,
+        default_instruction_task=config.default_instruction_task,
+        default_lora_task=config.default_lora_task,
         pooler_type=config.pooler_type,
         proj_type=config.proj_type,
         proj_bias=config.proj_bias,
@@ -532,33 +534,25 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
 
         truncate_dim = truncate_dim or self.config.truncate_dim
 
-
+        instruction = self.text_model.default_instruction
+        loraid = self.text_model.default_loraid
         if task:
-            ...
-            elif task not in self.text_model._task_instructions:
-                raise ValueError(
-                    f'Unsupported task \'{task}\'. Choose one of the following: '
-                    f'{", ".join(self.text_model._task_instructions)} or bypass the '
-                    '`task` argument to disable task instructions completely.'
-                )
-            else:
-                instruction = self.text_model._task_instructions[task]
-                sentences = [instruction + sentence for sentence in sentences]
+            _selected_instruction = self.text_model.get_instruction_from_task(task)
+            if _selected_instruction is not None:
+                instruction = _selected_instruction
+            _selected_loraid = self.text_model.get_loraid_from_task(task)
+            if _selected_loraid is not None:
+                loraid = _selected_loraid
+
+        if instruction is not None:
+            sentences = [instruction + sentence for sentence in sentences]
+
+        adapter_mask = None
+        if loraid is not None:
+            nexamples = 1 if isinstance(sentences, str) else len(sentences)
+            adapter_mask = torch.full(
+                (nexamples,), loraid, dtype=torch.int32, device=self.device
+            )
 
         for i in range_iter:
             tokens = self.tokenizer(
@@ -566,7 +560,6 @@ class JinaCLIPModel(JinaCLIPPreTrainedModel):
             return_tensors='pt',
             **tokenizer_kwargs,
         ).to(self.device)
-
         embeddings = self.get_text_features(
             input_ids=tokens, adapter_mask=adapter_mask
         )
```
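The resolved LoRA id is turned into an `adapter_mask` with one `int32` entry per input example, which `get_text_features` then forwards to the transformer. (Part of the removed branch did not survive in the page source and is elided as `...` in the hunk above.) The mask construction is easy to sanity-check in isolation with plain torch; the `loraid` value here is illustrative:

```python
import torch

loraid = 2  # illustrative adapter index resolved from a task name
sentences = ['a query', 'another query']

# One entry per example, all selecting the same LoRA adapter.
nexamples = 1 if isinstance(sentences, str) else len(sentences)
adapter_mask = torch.full((nexamples,), loraid, dtype=torch.int32)
print(adapter_mask)  # tensor([2, 2], dtype=torch.int32)
```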
rope_embeddings.py CHANGED

```diff
@@ -3,7 +3,6 @@
 # https://github.com/baaivision/EVA/tree/master/EVA-CLIP/rei/eva_clip
 # --------------------------------------------------------
 
-import logging
 from math import pi
 
 import torch
@@ -75,10 +74,8 @@ class VisionRotaryEmbedding(nn.Module):
 
         freqs = broadcast((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)
 
-        self.register_buffer('freqs_cos', freqs.cos())
-        self.register_buffer('freqs_sin', freqs.sin())
-
-        logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
+        self.register_buffer('freqs_cos', freqs.cos(), persistent=False)
+        self.register_buffer('freqs_sin', freqs.sin(), persistent=False)
 
     def forward(self, t, start_index=0):
         rot_dim = self.freqs_cos.shape[-1]
@@ -137,10 +134,8 @@ class VisionRotaryEmbeddingFast(nn.Module):
 
         self.patch_dropout = patch_dropout
 
-        self.register_buffer('freqs_cos', freqs_cos)
-        self.register_buffer('freqs_sin', freqs_sin)
-
-        logging.info(f'Shape of rope freq: {self.freqs_cos.shape}')
+        self.register_buffer('freqs_cos', freqs_cos, persistent=False)
+        self.register_buffer('freqs_sin', freqs_sin, persistent=False)
 
     def forward(self, t, patch_indices_keep=None):
         if patch_indices_keep is not None:
```