add unmerged fixes (#1)

- update pr (afa2fbbd041133a280b818ea4b23a9e61634384c)

Files changed:
- config.json +1 -1
- configuration_florence2.py +2 -2
- modeling_florence2.py +14 -11
config.json CHANGED

@@ -80,6 +80,6 @@
   },
   "vocab_size": 51289,
   "torch_dtype": "float16",
-  "transformers_version": "4.
+  "transformers_version": "4.49.0",
   "is_encoder_decoder": true
 }
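The only change here pins `transformers_version` to 4.49.0, recording the library version the updated remote code targets. A quick way to inspect the raw field, with the repo id below as a placeholder:

```python
import json
from huggingface_hub import hf_hub_download

# Fetch the checkpoint's raw config.json and read the pinned fields.
path = hf_hub_download("microsoft/Florence-2-base", "config.json")  # placeholder repo id
with open(path) as f:
    cfg = json.load(f)

print(cfg["transformers_version"])  # "4.49.0" after this change
print(cfg["vocab_size"])            # 51289
print(cfg["is_encoder_decoder"])    # True
```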
configuration_florence2.py CHANGED

@@ -77,7 +77,7 @@ class Florence2VisionConfig(PretrainedConfig):
     >>> configuration = model.config
     ```"""
 
-    model_type = "
+    model_type = "davit"
    keys_to_ignore_at_inference = ["past_key_values"]
 
     def __init__(
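Giving `Florence2VisionConfig` a concrete `model_type` matters because the value is serialized into config.json and is what the auto classes use to route a saved dict back to its config class. A toy illustration (the demo class name is hypothetical):

```python
from transformers import PretrainedConfig

class DavitDemoConfig(PretrainedConfig):
    # A non-empty model_type is written into the serialized config and
    # lets the auto classes map the saved dict back to this class.
    model_type = "davit"

print(DavitDemoConfig().to_dict()["model_type"])  # "davit"
```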
@@ -327,7 +327,7 @@ class Florence2Config(PretrainedConfig):
         self.vocab_size = vocab_size
         self.projection_dim = projection_dim
         if vision_config is not None:
-            vision_config =
+            vision_config = Florence2VisionConfig(**vision_config)
         self.vision_config = vision_config
         self.vocab_size = self.vocab_size
 
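The second hunk promotes the raw `vision_config` dict (as deserialized from config.json) into a typed `Florence2VisionConfig`, so nested fields get proper defaults and attribute access. A minimal sketch of the pattern, using hypothetical demo classes:

```python
from transformers import PretrainedConfig

class DemoVisionConfig(PretrainedConfig):
    model_type = "demo_vision"

    def __init__(self, drop_path_rate=0.1, **kwargs):
        super().__init__(**kwargs)
        self.drop_path_rate = drop_path_rate

class DemoConfig(PretrainedConfig):
    model_type = "demo"

    def __init__(self, vision_config=None, **kwargs):
        super().__init__(**kwargs)
        if vision_config is not None:
            # Promote the plain dict into a typed sub-config, mirroring
            # vision_config = Florence2VisionConfig(**vision_config) above.
            vision_config = DemoVisionConfig(**vision_config)
        self.vision_config = vision_config

cfg = DemoConfig(vision_config={"drop_path_rate": 0.2})
print(type(cfg.vision_config).__name__)  # DemoVisionConfig
print(cfg.vision_config.drop_path_rate)  # 0.2
```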
modeling_florence2.py CHANGED

@@ -26,7 +26,7 @@ import torch.utils.checkpoint as checkpoint
 from torch.nn import CrossEntropyLoss
 from collections import OrderedDict
 from einops import rearrange
-from timm.models.layers import DropPath, trunc_normal_
+from timm.layers import DropPath, trunc_normal_
 
 from transformers.modeling_utils import PreTrainedModel
 from transformers.generation.utils import GenerationMixin
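The import moves to `timm.layers`, the current home of these utilities (`timm.models.layers` is a deprecated alias that recent timm releases drop). A quick sanity check of the two imported helpers, assuming a recent timm:

```python
import torch
from timm.layers import DropPath, trunc_normal_

# trunc_normal_ fills a tensor in place from a truncated normal distribution.
weight = torch.empty(8, 8)
trunc_normal_(weight, std=0.02)

# DropPath implements stochastic depth: during training it randomly zeroes
# an entire residual branch per sample.
drop_path = DropPath(drop_prob=0.1)
x = torch.randn(4, 8)
out = x + drop_path(torch.relu(x))  # toy residual block
print(out.shape)  # torch.Size([4, 8])
```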
@@ -2080,8 +2080,8 @@ class Florence2LanguageForConditionalGeneration(Florence2LanguagePreTrainedModel
     def get_decoder(self):
         return self.model.get_decoder()
 
-    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None) -> nn.Embedding:
-        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+    def resize_token_embeddings(self, new_num_tokens: int, pad_to_multiple_of: Optional[int] = None, **kwargs) -> nn.Embedding:
+        new_embeddings = super().resize_token_embeddings(new_num_tokens, pad_to_multiple_of, **kwargs)
         self._resize_final_logits_bias(new_embeddings.weight.shape[0])
         return new_embeddings
 
@@ -2589,8 +2589,8 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
     def get_input_embeddings(self):
         return self.language_model.get_input_embeddings()
 
-    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None) -> nn.Embedding:
-        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of)
+    def resize_token_embeddings(self, new_num_tokens: Optional[int] = None, pad_to_multiple_of=None, **kwargs) -> nn.Embedding:
+        model_embeds = self.language_model.resize_token_embeddings(new_num_tokens, pad_to_multiple_of, **kwargs)
         # update vocab size
         self.config.text_config.vocab_size = model_embeds.num_embeddings
         self.config.vocab_size = model_embeds.num_embeddings
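Both `resize_token_embeddings` overrides now forward `**kwargs` to the underlying implementation, presumably to absorb keyword arguments added in newer transformers releases (e.g. `mean_resizing`) instead of raising a `TypeError` when they reach this remote code. A usage sketch, with the repo id as a placeholder:

```python
from transformers import AutoModelForCausalLM, AutoProcessor

repo = "microsoft/Florence-2-base"  # placeholder repo id
processor = AutoProcessor.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)

# With **kwargs forwarded, extra keyword arguments from recent transformers
# releases pass through cleanly instead of breaking the override.
model.resize_token_embeddings(
    len(processor.tokenizer),
    pad_to_multiple_of=64,  # optional: pad vocab for tensor-core friendliness
    mean_resizing=True,     # accepted by recent transformers releases
)
```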
@@ -2644,7 +2644,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
         return x
 
     def _merge_input_ids_with_image_features(
-        self, image_features, inputs_embeds
+        self, image_features, inputs_embeds, task_prefix_attention_mask=None
     ):
         batch_size, image_token_length = image_features.size()[:-1]
         device = image_features.device
@@ -2656,10 +2656,11 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
             return image_features, image_attention_mask
 
         task_prefix_embeds = inputs_embeds
-        task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device)
+        if task_prefix_attention_mask is None:
+            task_prefix_attention_mask = torch.ones(batch_size, task_prefix_embeds.size(1), device=device)
 
-        if len(task_prefix_attention_mask.shape) == 3:
-            task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
+        if len(task_prefix_attention_mask.shape) == 3:
+            task_prefix_attention_mask = task_prefix_attention_mask[:, 0]
 
         # concat [image embeds, task prefix embeds]
         inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1)
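`_merge_input_ids_with_image_features` now accepts the caller's text attention mask instead of assuming every prefix token is real; the image tokens are prepended and the two masks are concatenated alongside the embeddings. A shape-only sketch of that merge (all sizes hypothetical):

```python
import torch

batch_size, num_image_tokens, num_text_tokens, hidden = 2, 577, 12, 768

image_features = torch.randn(batch_size, num_image_tokens, hidden)
task_prefix_embeds = torch.randn(batch_size, num_text_tokens, hidden)

# Image tokens are always attended to; the text mask may contain padding.
image_attention_mask = torch.ones(batch_size, num_image_tokens)
task_prefix_attention_mask = torch.ones(batch_size, num_text_tokens)
task_prefix_attention_mask[:, -3:] = 0  # pretend the last 3 positions are pad

# Concatenate along the sequence dimension: [image tokens | text tokens].
inputs_embeds = torch.cat([image_features, task_prefix_embeds], dim=1)
attention_mask = torch.cat([image_attention_mask, task_prefix_attention_mask], dim=1)

print(inputs_embeds.shape)   # torch.Size([2, 589, 768])
print(attention_mask.shape)  # torch.Size([2, 589])
```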
@@ -2735,7 +2736,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
         if pixel_values is not None:
             # (batch_size, num_image_tokens, hidden_size)
             image_features = self._encode_image(pixel_values)
-            inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds)
+            inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds, task_prefix_attention_mask=attention_mask)
 
         if inputs_embeds is not None:
             attention_mask = attention_mask.to(inputs_embeds.dtype)
@@ -2782,6 +2783,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
         input_ids,
         inputs_embeds=None,
         pixel_values=None,
+        attention_mask=None,
         **kwargs
     ):
 
@@ -2792,11 +2794,12 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel):
         # 2. Merge text and images
         if pixel_values is not None:
             image_features = self._encode_image(pixel_values)
-            inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds)
+            inputs_embeds, attention_mask = self._merge_input_ids_with_image_features(image_features, inputs_embeds, task_prefix_attention_mask=attention_mask)
 
         return self.language_model.generate(
             input_ids=None,
             inputs_embeds=inputs_embeds,
+            attention_mask=attention_mask,
             **kwargs
         )
 
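With these hunks, `forward` and `generate` thread a caller-supplied `attention_mask` through the merge and on to `language_model.generate`, so padded positions in a batch are no longer attended to. A usage sketch along the lines of the model card, with the repo id as a placeholder:

```python
import requests
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

repo = "microsoft/Florence-2-base"  # placeholder repo id
processor = AutoProcessor.from_pretrained(repo, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(repo, trust_remote_code=True)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
inputs = processor(text="<CAPTION>", images=image, return_tensors="pt")

# attention_mask is now accepted by generate and forwarded through the merge.
generated_ids = model.generate(
    input_ids=inputs["input_ids"],
    pixel_values=inputs["pixel_values"],
    attention_mask=inputs["attention_mask"],
    max_new_tokens=64,
)
print(processor.batch_decode(generated_ids, skip_special_tokens=False)[0])
```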