Finalize Sentence Transformers integration

#3 opened by tomaarsen (HF staff)
Files changed (4)
  1. 1_Pooling/config.json +10 -0
  2. README.md +49 -7
  3. custom_st.py +22 -31
  4. modules.json +20 -0
1_Pooling/config.json ADDED
@@ -0,0 +1,10 @@
+{
+    "word_embedding_dimension": 4096,
+    "pooling_mode_cls_token": false,
+    "pooling_mode_mean_tokens": false,
+    "pooling_mode_max_tokens": false,
+    "pooling_mode_mean_sqrt_len_tokens": false,
+    "pooling_mode_weightedmean_tokens": false,
+    "pooling_mode_lasttoken": true,
+    "include_prompt": true
+}
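
Note: this config routes pooling through the stock `sentence_transformers.models.Pooling` module with `pooling_mode_lasttoken` enabled, replacing the hand-rolled `_last_pooling` helper removed from `custom_st.py` below. For reference, a minimal standalone sketch of what last-token pooling computes (my own illustration assuming right-padded batches, not the library's implementation):

```python
import torch

def last_token_pooling(token_embeddings: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # Index of the last non-padding token in each (right-padded) sequence
    sequence_lengths = attention_mask.sum(dim=1) - 1
    batch_indices = torch.arange(token_embeddings.shape[0], device=token_embeddings.device)
    # The hidden state of that final token becomes the sentence embedding
    return token_embeddings[batch_indices, sequence_lengths]

# Toy batch: 2 sequences of 4 tokens with 8-dim embeddings; the second is padded by one
token_embeddings = torch.randn(2, 4, 8)
attention_mask = torch.tensor([[1, 1, 1, 1], [1, 1, 1, 0]])
print(last_token_pooling(token_embeddings, attention_mask).shape)  # torch.Size([2, 8])
```

The L2 normalization that used to be fused into `_last_pooling` now lives in the separate `2_Normalize` module declared in `modules.json`.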
README.md CHANGED
@@ -2,6 +2,7 @@
 tags:
 - mmeb
 - transformers
+- sentence-transformers
 language:
 - en
 - ar
@@ -34,15 +35,10 @@ Our model achieves SOTA performance on MMEB benchmark.
 
 ## Usage
 
+### Transformers
+
 Below is an example we adapted from [VLM2Vec](https://huggingface.co/TIGER-Lab/VLM2Vec-Full).
 
-First clone github
-```bash
-git clone https://github.com/haon-chen/mmE5.git
-pip install -r requirements.txt
-```
-
-Then you can enter the directory to run the following command.
 ```python
 import torch
 import requests
@@ -107,6 +103,52 @@ print(string, '=', compute_similarity(qry_output, tgt_output))
 ## <|image|><|begin_of_text|> Represent the given image. = tensor([[0.3887]], device='cuda:0', dtype=torch.bfloat16)
 ```
 
+### Sentence Transformers
+
+You can also use Sentence Transformers, where the majority of the pre- and post-processing has been abstracted away.
+
+```python
+from sentence_transformers import SentenceTransformer
+import requests
+
+# Load the model
+model = SentenceTransformer("intfloat/mmE5-mllama-11b-instruct", trust_remote_code=True)
+
+# Download an example image of a cat and a dog
+dog_cat_image_bytes = requests.get('https://github.com/haon-chen/mmE5/blob/main/figures/example.jpg?raw=true', stream=True).raw.read()
+with open("cat_dog_example.jpg", "wb") as f:
+    f.write(dog_cat_image_bytes)
+
+# Image + Text -> Text
+image_embeddings = model.encode([{
+    "image": "cat_dog_example.jpg",
+    "text": "Represent the given image with the following question: What is in the image",
+}])
+text_embeddings = model.encode([
+    {"text": "A cat and a dog"},
+    {"text": "A cat and a tiger"},
+])
+
+similarity = model.similarity(image_embeddings, text_embeddings)
+print(similarity)
+# tensor([[0.3967, 0.3090]])
+# ✅ The first text is most similar to the image
+
+# Text -> Image
+image_embeddings = model.encode([
+    {"image": dog_cat_image_bytes, "text": "Represent the given image."},
+])
+text_embeddings = model.encode([
+    {"text": "Find me an everyday image that matches the given caption: A cat and a dog."},
+    {"text": "Find me an everyday image that matches the given caption: A cat and a tiger."},
+])
+
+similarity = model.similarity(image_embeddings, text_embeddings)
+print(similarity)
+# tensor([[0.4250, 0.3896]])
+# ✅ The first text is most similar to the image
+```
+
 ## Citation
 ```
 @article{chen2025mmE5,
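
Because the pipeline ends with a `Normalize` module, the embeddings returned by `encode` are unit-length, so `model.similarity` (cosine by default) agrees with a plain dot product. A quick sanity check, continuing from the example above (assuming the default float32 NumPy output of `encode`):

```python
import numpy as np

# Unit-length embeddings: cosine similarity equals the dot product
print(model.similarity(image_embeddings, text_embeddings))  # cosine (the default)
print(image_embeddings @ text_embeddings.T)                 # same values, up to float precision
print(np.linalg.norm(text_embeddings, axis=1))              # ~[1., 1.]
```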
custom_st.py CHANGED
@@ -17,6 +17,7 @@ class MultiModalTransformer(BaseTransformer):
         super().__init__(model_name_or_path, **kwargs)
         if tokenizer_args is None:
             tokenizer_args = {}
+        tokenizer_args.pop("trust_remote_code", None)
 
         # Initialize processor
         self.processor = AutoProcessor.from_pretrained(
@@ -32,6 +33,7 @@ class MultiModalTransformer(BaseTransformer):
         is_peft_model: bool,
         **model_args,
     ) -> None:
+        model_args.pop("trust_remote_code", None)
         self.auto_model = MllamaForConditionalGeneration.from_pretrained(
             model_name_or_path, torch_dtype=torch.bfloat16, cache_dir=cache_dir, **model_args
         )
@@ -47,49 +49,38 @@ class MultiModalTransformer(BaseTransformer):
             **kwargs
         )
 
-        # Apply last pooling and normalization
-        last_hidden_state = outputs.hidden_states[-1]
-        attention_mask = features["attention_mask"]
-        sentence_embedding = self._last_pooling(last_hidden_state, attention_mask)
-
-        features.update({"sentence_embedding": sentence_embedding})
+        features.update({"token_embeddings": outputs.hidden_states[-1]})
         return features
 
-    def _last_pooling(self, last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
-        """Apply last token pooling and L2 normalization"""
-        sequence_lengths = attention_mask.sum(dim=1) - 1
-        batch_size = last_hidden_state.shape[0]
-        reps = last_hidden_state[torch.arange(batch_size, device=last_hidden_state.device), sequence_lengths]
-        return torch.nn.functional.normalize(reps, p=2, dim=-1)
-
     def tokenize(self, texts: List[List[Dict]] | List[str]) -> Dict[str, torch.Tensor]:
         def process_text_item(item):
             if isinstance(item, str):
-                return item, []
+                return item, None
 
-            text, images = "", []
-            for sub_item in item:
-                if sub_item["type"] == "text":
-                    text += sub_item["content"]
-                elif sub_item["type"] in ["image_bytes", "image_path"]:
-                    text += "<|image|>"
-                    if sub_item["type"] == "image_bytes":
-                        img = Image.open(BytesIO(sub_item["content"])).convert("RGB")
-                    else:
-                        img = Image.open(sub_item["content"]).convert("RGB")
-                    images.append(img)
-                else:
-                    raise ValueError(f"Unknown data type {sub_item['type']}")
-            return text, images
+            text, img = "", None
+            if "image" in item:
+                text += "<|image|>"
+                img = item["image"]
+                if isinstance(img, bytes):
+                    img = Image.open(BytesIO(img)).convert("RGB")
+                elif isinstance(img, str):
+                    img = Image.open(img).convert("RGB")
+                elif not isinstance(img, Image.Image):
+                    raise ValueError(f"Unknown image type {type(img)}")
+            if "text" in item:
+                if text:
+                    text += "<|begin_of_text|> "
+                text += item["text"].lstrip()
+
+            return text, img
 
         all_texts, all_images = [], []
         for item in texts:
             text, images = process_text_item(item)
             all_texts.append(text)
-            all_images.extend(images)
+            all_images.append(images)
 
-        # Process inputs through the processor
-        if all_images:
+        if all_images != [None] * len(all_images):
            inputs = self.processor(
                text=all_texts,
                images=all_images,
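
The reworked `tokenize` accepts either plain strings or dicts with optional `image` and `text` keys, where the image may be a file path, raw bytes, or a `PIL.Image.Image`. A sketch of the accepted input shapes (assuming the model and `cat_dog_example.jpg` from the README example):

```python
from PIL import Image
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("intfloat/mmE5-mllama-11b-instruct", trust_remote_code=True)

# All of the following are accepted by the reworked tokenize():
model.encode("A plain text query")                          # bare string
model.encode([{"text": "A cat and a dog"}])                 # text-only dict
model.encode([{"image": "cat_dog_example.jpg",              # image as a file path
               "text": "Represent the given image."}])
model.encode([{"image": open("cat_dog_example.jpg", "rb").read(),  # image as raw bytes
               "text": "Represent the given image."}])
model.encode([{"image": Image.open("cat_dog_example.jpg"),  # image as a PIL.Image
               "text": "Represent the given image."}])
```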
modules.json ADDED
@@ -0,0 +1,20 @@
+[
+    {
+        "idx": 0,
+        "name": "0",
+        "path": "",
+        "type": "custom_st.MultiModalTransformer"
+    },
+    {
+        "idx": 1,
+        "name": "1",
+        "path": "1_Pooling",
+        "type": "sentence_transformers.models.Pooling"
+    },
+    {
+        "idx": 2,
+        "name": "2",
+        "path": "2_Normalize",
+        "type": "sentence_transformers.models.Normalize"
+    }
+]
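
`modules.json` wires the three stages together: the custom multimodal transformer emits `token_embeddings`, `Pooling` (configured above for last-token pooling) reduces them to one vector per input, and `Normalize` L2-normalizes it. A rough sketch of the equivalent stack assembled by hand; `SentenceTransformer(..., trust_remote_code=True)` normally builds this automatically from `modules.json`:

```python
from sentence_transformers import SentenceTransformer, models
from custom_st import MultiModalTransformer  # assumes custom_st.py is importable

# Mirror the three modules declared in modules.json
transformer = MultiModalTransformer("intfloat/mmE5-mllama-11b-instruct")
pooling = models.Pooling(word_embedding_dimension=4096, pooling_mode="lasttoken")
normalize = models.Normalize()
model = SentenceTransformer(modules=[transformer, pooling, normalize])
```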