jcwang0602 and nielsr (HF Staff) committed
Commit f4573ea · verified · 1 Parent(s): 989019d

Improve model card: Add pipeline tag, library name, and paper/code links (#1)


- Improve model card: Add pipeline tag, library name, and paper/code links (bad64b4f630d2e1b221a4de7654b62484f2b1238)


Co-authored-by: Niels Rogge <[email protected]>

Files changed (1): README.md (+205 -3)
README.md CHANGED
@@ -1,3 +1,205 @@
- ---
- license: mit
- ---

---
license: mit
pipeline_tag: image-segmentation
library_name: transformers
---

# MLLMSeg: Unlocking the Potential of MLLMs in Referring Expression Segmentation via a Light-weight Mask Decoder

This repository contains `MLLMSeg`, a novel framework for Referring Expression Segmentation (RES) and Generalized Referring Expression Segmentation (GRES), presented in the paper [Unlocking the Potential of MLLMs in Referring Expression Segmentation via a Light-weight Mask Decoder](https://huggingface.co/papers/2508.04107).

**Abstract:** Referring Expression Segmentation (RES) aims to segment image regions specified by referring expressions and has become popular with the rise of multimodal large language models (MLLMs). While MLLMs excel in semantic understanding, their token-generation paradigm struggles with pixel-level dense prediction. Existing RES methods either couple MLLMs with the parameter-heavy Segment Anything Model (SAM), which has 632M network parameters, or adopt SAM-free lightweight pipelines that sacrifice accuracy. To address the trade-off between performance and cost, we propose MLLMSeg, a novel framework that fully exploits the visual detail features already encoded by the MLLM vision encoder, without introducing an extra visual encoder. We further propose a detail-enhanced and semantic-consistent feature fusion module (DSFF) that integrates these detail-related visual features with the semantic-related features output by the large language model (LLM) of the MLLM. Finally, we establish a light-weight mask decoder with only 34M network parameters that leverages the detailed spatial features from the visual encoder and the semantic features from the LLM to achieve precise mask prediction. Extensive experiments demonstrate that our method generally surpasses both SAM-based and SAM-free competitors, striking a better balance between performance and cost.

<p align="center">
  <img src="https://github.com/jcwang0602/MLLMSeg/raw/main/assets/method.png" width="800">
</p>
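
For intuition, the sketch below mirrors the data flow described in the abstract: detail features from the MLLM vision encoder and a semantic feature from its LLM are fused by a DSFF-style module, then decoded into a mask by a small decoder. This is a simplified illustration only, not the authors' implementation; the module names, channel sizes, and tensor shapes are assumptions.

```python
import torch
import torch.nn as nn

class DSFFSketch(nn.Module):
    """Illustrative detail-enhanced, semantic-consistent fusion (not the paper's code)."""
    def __init__(self, detail_dim=1024, semantic_dim=4096, fused_dim=256):
        super().__init__()
        self.detail_proj = nn.Conv2d(detail_dim, fused_dim, kernel_size=1)  # vision-encoder detail features
        self.semantic_proj = nn.Linear(semantic_dim, fused_dim)             # LLM semantic feature (e.g. [SEG] hidden state)
        self.fuse = nn.Sequential(nn.Conv2d(2 * fused_dim, fused_dim, 3, padding=1), nn.GELU())

    def forward(self, detail_feats, semantic_feat):
        # detail_feats: (B, C_v, H, W); semantic_feat: (B, C_l)
        d = self.detail_proj(detail_feats)
        s = self.semantic_proj(semantic_feat)[:, :, None, None].expand_as(d)
        return self.fuse(torch.cat([d, s], dim=1))

class LightMaskDecoderSketch(nn.Module):
    """Illustrative light-weight mask decoder: a few upsampling blocks ending in 1-channel mask logits."""
    def __init__(self, in_dim=256):
        super().__init__()
        self.decode = nn.Sequential(
            nn.ConvTranspose2d(in_dim, in_dim // 2, 2, stride=2), nn.GELU(),
            nn.ConvTranspose2d(in_dim // 2, in_dim // 4, 2, stride=2), nn.GELU(),
            nn.Conv2d(in_dim // 4, 1, 1),
        )

    def forward(self, fused):
        return self.decode(fused)

# Toy shapes, only to show how the pieces connect:
detail = torch.randn(1, 1024, 32, 32)    # assumed vision-encoder feature map
semantic = torch.randn(1, 4096)          # assumed LLM hidden state
mask_logits = LightMaskDecoderSketch()(DSFFSketch()(detail, semantic))
print(mask_logits.shape)  # torch.Size([1, 1, 128, 128])
```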

## Paper and Code
* **Paper:** [Unlocking the Potential of MLLMs in Referring Expression Segmentation via a Light-weight Mask Decoder](https://huggingface.co/papers/2508.04107)
* **GitHub Repository:** [https://github.com/jcwang0602/MLLMSeg](https://github.com/jcwang0602/MLLMSeg)

## Usage

You can use `MLLMSeg` with the `transformers` library. The model takes an image and a referring expression as input and outputs a segmentation mask and/or coordinates. Images of any size are accepted. Predicted coordinates are normalized to a 0-1000 range relative to the image (either a center point or a bounding box given by its top-left and bottom-right corners), so convert them back to the original image dimensions before visualization; a parsing sketch is given after the sample code below.
24
+
25
+ ### Installation
26
+
27
+ First, install the necessary dependencies:
28
+
29
+ ```bash
30
+ conda create -n mllmseg python==3.10.18 -y
31
+ conda activate mllmseg
32
+ pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu118
33
+ pip install -r requirements.txt
34
+ pip install flash-attn==2.3.6 --no-build-isolation # Note: need gpu to install
35
+ ```
36
+
37
+ ### Sample Usage
38
+
39
+ Here's a basic example demonstrating how to load and use the model for inference. We'll use the `MLLMSeg_InternVL2_5_8B_RES` model as an example.
40
+
41
+ ```python
42
+ import torch
43
+ import torchvision.transforms as T
44
+ from PIL import Image
45
+ from torchvision.transforms.functional import InterpolationMode
46
+ from transformers import AutoModel, AutoTokenizer
47
+ import requests
48
+ from io import BytesIO
49
+
50
+ # --- Helper functions for image preprocessing (from original GitHub repo) ---
51
+ IMAGENET_MEAN = (0.485, 0.456, 0.406)
52
+ IMAGENET_STD = (0.229, 0.224, 0.225)
53
+
54
+ def build_transform(input_size):
55
+ MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
56
+ transform = T.Compose([
57
+ T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
58
+ T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
59
+ T.ToTensor(),
60
+ T.Normalize(mean=MEAN, std=STD)
61
+ ])
62
+ return transform
63
+
64
+ def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
65
+ best_ratio_diff = float('inf')
66
+ best_ratio = (1, 1)
67
+ area = width * height
68
+ for ratio in target_ratios:
69
+ target_aspect_ratio = ratio[0] / ratio[1]
70
+ ratio_diff = abs(aspect_ratio - target_aspect_ratio)
71
+ if ratio_diff < best_ratio_diff:
72
+ best_ratio_diff = ratio_diff
73
+ best_ratio = ratio
74
+ elif ratio_diff == best_ratio_diff:
75
+ if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
76
+ best_ratio = ratio
77
+ return best_ratio
78
+
79
+ def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=True):
80
+ orig_width, orig_height = image.size
81
+ aspect_ratio = orig_width / orig_height
82
+
83
+ target_ratios = set(
84
+ (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
85
+ i * j <= max_num and i * j >= min_num)
86
+ target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
87
+
88
+ target_aspect_ratio = find_closest_aspect_ratio(
89
+ aspect_ratio, target_ratios, orig_width, orig_height, image_size)
90
+
91
+ target_width = image_size * target_aspect_ratio[0]
92
+ target_height = image_size * target_aspect_ratio[1]
93
+ blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
94
+
95
+ resized_img = image.resize((target_width, target_height))
96
+ processed_images = []
97
+ for i in range(blocks):
98
+ box = (
99
+ (i % (target_width // image_size)) * image_size,
100
+ (i // (target_width // image_size)) * image_size,
101
+ ((i % (target_width // image_size)) + 1) * image_size,
102
+ ((i // (target_width // image_size)) + 1) * image_size
103
+ )
104
+ split_img = resized_img.crop(box)
105
+ processed_images.append(split_img)
106
+ assert len(processed_images) == blocks
107
+ if use_thumbnail and len(processed_images) != 1:
108
+ thumbnail_img = image.resize((image_size, image_size))
109
+ processed_images.append(thumbnail_img)
110
+ return processed_images
111
+
112
+ def load_image(image_file_or_url, input_size=448, max_num=6):
113
+ if isinstance(image_file_or_url, str) and image_file_or_url.startswith("http"):
114
+ response = requests.get(image_file_or_url, stream=True)
115
+ image = Image.open(BytesIO(response.content)).convert('RGB')
116
+ else:
117
+ image = Image.open(image_file_or_url).convert('RGB')
118
+
119
+ transform = build_transform(input_size=input_size)
120
+ images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
121
+ pixel_values = [transform(image) for image in images]
122
+ pixel_values = torch.stack(pixel_values)
123
+ return pixel_values
124
+ # --- End of helper functions ---
125
+
126
+ # Load model and tokenizer
127
+ model_id = "jcwang0602/MLLMSeg_InternVL2_5_8B_RES"
128
+ model = AutoModel.from_pretrained(
129
+ model_id,
130
+ torch_dtype=torch.bfloat16,
131
+ low_cpu_mem_usage=True,
132
+ trust_remote_code=True
133
+ ).eval().cuda()
134
+ tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, use_fast=False)
135
+
136
+ # Example image and question
137
+ # Using an example image from the MLLMSeg repository for demonstration
138
+ image_url = "https://huggingface.co/jcwang0602/MLLMSeg_InternVL2_5_8B_RES/resolve/main/assets/res_example.png"
139
+ question = "Please give me the segmentation mask of the dog (with [SEG])."
140
+
141
+ # Preprocess image
142
+ pixel_values = load_image(image_url, max_num=6).to(torch.bfloat16).cuda()
143
+ generation_config = dict(max_new_tokens=1024, do_sample=True)
144
+
145
+ # Generate response
146
+ response, history = model.chat(tokenizer, pixel_values, question, generation_config, history=None, return_history=True)
print(f'User: {question}\nAssistant: {response}')

# The output `response` contains the segmentation information (e.g., coordinates or a [SEG]-token-based output).
# Parse this string to extract the mask or coordinates for visualization.
```
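
The exact textual format of `response` depends on the checkpoint and prompt, so the following is only a sketch: it assumes the answer embeds a bracketed coordinate list such as `[[x1, y1, x2, y2]]` in the 0-1000 normalized range, extracts it with a regular expression, and rescales it to the original image size. The regex, the helper name, and the assumed format are illustrative, not part of the official API; for `[SEG]`-based masks, see the inference scripts in the GitHub repository.

```python
import re

def parse_normalized_coords(response_text, image_width, image_height):
    """Sketch: extract [[...]]-style coordinates (0-1000 normalized) from the model's
    text answer and rescale them to pixel coordinates. The format is an assumption."""
    results = []
    for match in re.findall(r"\[\[([\d\s.,]+?)\]\]", response_text):
        values = [float(v) for v in match.replace(",", " ").split()]
        if len(values) == 4:  # bounding box: top-left and bottom-right corners
            x1, y1, x2, y2 = values
            results.append((x1 / 1000 * image_width, y1 / 1000 * image_height,
                            x2 / 1000 * image_width, y2 / 1000 * image_height))
        elif len(values) == 2:  # center point
            cx, cy = values
            results.append((cx / 1000 * image_width, cy / 1000 * image_height))
    return results

# Hypothetical response string and original (pre-preprocessing) image size:
print(parse_normalized_coords("The dog is at [[112, 240, 655, 930]].", 640, 480))
# [(71.68, 115.2, 419.2, 446.4)]
```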

## Checkpoints

Our checkpoints are available at:

| Base Model | RES Model | GRES Model |
|----------------|------------------------------------------------------------|------------------------------------------------------------|
| InternVL2_5_1B | [MLLMSeg_InternVL2_5_1B_RES](https://huggingface.co/jcwang0602/MLLMSeg_InternVL2_5_1B_RES) | - |
| InternVL2_5_2B | [MLLMSeg_InternVL2_5_2B_RES](https://huggingface.co/jcwang0602/MLLMSeg_InternVL2_5_2B_RES) | - |
| InternVL2_5_4B | [MLLMSeg_InternVL2_5_4B_RES](https://huggingface.co/jcwang0602/MLLMSeg_InternVL2_5_4B_RES) | - |
| InternVL2_5_8B | [MLLMSeg_InternVL2_5_8B_RES](https://huggingface.co/jcwang0602/MLLMSeg_InternVL2_5_8B_RES) | [MLLMSeg_InternVL2_5_8B_GRES](https://huggingface.co/jcwang0602/MLLMSeg_InternVL2_5_8B_GRES) |
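
To fetch one of these checkpoints locally (e.g., for offline use), a minimal sketch using `huggingface_hub` is shown below; the target directory is just an example.

```python
from huggingface_hub import snapshot_download

# Download the 8B RES checkpoint into a local folder (path is an example).
local_dir = snapshot_download(
    repo_id="jcwang0602/MLLMSeg_InternVL2_5_8B_RES",
    local_dir="./checkpoints/MLLMSeg_InternVL2_5_8B_RES",
)
print(local_dir)
```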

## Performance Metrics

### Referring Expression Segmentation
<img src="https://github.com/jcwang0602/MLLMSeg/raw/main/assets/tab_res.png" width="800">

### Referring Expression Comprehension
<img src="https://github.com/jcwang0602/MLLMSeg/raw/main/assets/tab_rec.png" width="800">

### Generalized Referring Expression Segmentation
<img src="https://github.com/jcwang0602/MLLMSeg/raw/main/assets/tab_gres.png" width="800">

## Visualization

### Referring Expression Segmentation
<img src="https://github.com/jcwang0602/MLLMSeg/raw/main/assets/res.png" width="800">

### Referring Expression Comprehension
<img src="https://github.com/jcwang0602/MLLMSeg/raw/main/assets/rec.png" width="800">

### Generalized Referring Expression Segmentation
<img src="https://github.com/jcwang0602/MLLMSeg/raw/main/assets/gres.png" width="800">

## Citation

If our work is useful for your research, please consider citing:

```bibtex
@misc{wang2025unlockingpotentialmllmsreferring,
      title={Unlocking the Potential of MLLMs in Referring Expression Segmentation via a Light-weight Mask Decoder},
      author={Jingchao Wang and Zhijian Wu and Dingjiang Huang and Yefeng Zheng and Hong Wang},
      year={2025},
      eprint={2508.04107},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2508.04107},
}
```

## Acknowledgments

This code is developed on top of [InternVL](https://github.com/OpenGVLab/InternVL), [GSVA](https://github.com/LeapLabTHU/GSVA), and [EEVG](https://github.com/chenwei746/EEVG).