Improve model card: Add pipeline tag, library name, and paper/code links (#1)
Co-authored-by: Niels Rogge <[email protected]>
README.md
CHANGED
@@ -1,3 +1,205 @@
- ---
- license: mit
- ---

---
license: mit
pipeline_tag: image-segmentation
library_name: transformers
---

# MLLMSeg: Unlocking the Potential of MLLMs in Referring Expression Segmentation via a Light-weight Mask Decoder

This repository contains the `MLLMSeg` model, a novel framework for Referring Expression Segmentation (RES) and Generalized Referring Expression Segmentation (GRES), presented in the paper [Unlocking the Potential of MLLMs in Referring Expression Segmentation via a Light-weight Mask Decoder](https://huggingface.co/papers/2508.04107).

**Abstract:** Referring Expression Segmentation (RES) aims to segment image regions specified by referring expressions and has become popular with the rise of multimodal large models (MLLMs). While MLLMs excel at semantic understanding, their token-generation paradigm struggles with pixel-level dense prediction. Existing RES methods either couple MLLMs with the parameter-heavy Segment Anything Model (SAM), which has 632M network parameters, or adopt SAM-free lightweight pipelines that sacrifice accuracy. To address the trade-off between performance and cost, we propose MLLMSeg, a novel framework that fully exploits the visual detail features already encoded by the MLLM vision encoder, without introducing an extra visual encoder. In addition, we propose a detail-enhanced and semantic-consistent feature fusion module (DSFF) that fully integrates the detail-related visual features with the semantic-related features output by the large language model (LLM) of the MLLM. Finally, we establish a light-weight mask decoder with only 34M network parameters that optimally leverages detailed spatial features from the visual encoder and semantic features from the LLM to achieve precise mask prediction. Extensive experiments demonstrate that our method generally surpasses both SAM-based and SAM-free competitors, striking a better balance between performance and cost.

<p align="center">
<img src="https://github.com/jcwang0602/MLLMSeg/raw/main/assets/method.png" width="800">
</p>

## Paper and Code

* **Paper:** [Unlocking the Potential of MLLMs in Referring Expression Segmentation via a Light-weight Mask Decoder](https://huggingface.co/papers/2508.04107)
* **GitHub Repository:** [https://github.com/jcwang0602/MLLMSeg](https://github.com/jcwang0602/MLLMSeg)

## Usage

You can use `MLLMSeg` with the `transformers` library. The model takes an image and a referring expression as input and outputs a segmentation mask or coordinates. The models accept images of any size. Predicted points and boxes are returned as relative coordinates normalized to a 0-1000 range (either a center point or a bounding box given by its top-left and bottom-right corners); for visualization, convert these relative coordinates back to the original image dimensions, as sketched below.
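
Converting from the normalized 0-1000 space back to pixel coordinates is a simple rescale. The helper below is a minimal sketch of ours; the function name and the assumption that a box arrives as an `[x1, y1, x2, y2]` list are illustrative, not part of the model's API:

```python
# Minimal sketch: map a box predicted in the 0-1000 normalized space back to
# pixel coordinates of the original image. The [x1, y1, x2, y2] layout is an
# assumption for illustration; check the actual output format of your checkpoint.
def denormalize_box(box, image_width, image_height, scale=1000):
    x1, y1, x2, y2 = box
    return (
        int(x1 / scale * image_width),
        int(y1 / scale * image_height),
        int(x2 / scale * image_width),
        int(y2 / scale * image_height),
    )

# Example: a normalized box predicted on a 1280x720 image
print(denormalize_box([250, 100, 750, 900], 1280, 720))  # (320, 72, 960, 648)
```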

### Installation

First, install the necessary dependencies:

```bash
conda create -n mllmseg python==3.10.18 -y
conda activate mllmseg
pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu118
pip install -r requirements.txt
pip install flash-attn==2.3.6 --no-build-isolation  # Note: building flash-attn requires a GPU
```
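
Because `flash-attn` only builds in a CUDA-enabled environment, a quick sanity check before and after installation can save time. This snippet is our own sketch, not part of the official setup:

```python
# Sanity check (not from the original repo): confirm the GPU stack is usable.
import torch

print(torch.__version__, torch.version.cuda)  # expect 2.5.1 with a CUDA build
print(torch.cuda.is_available())              # should be True before installing flash-attn

try:
    import flash_attn
    print("flash-attn:", flash_attn.__version__)
except ImportError:
    print("flash-attn is not installed yet")
```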

### Sample Usage

Here's a basic example demonstrating how to load and use the model for inference, using `MLLMSeg_InternVL2_5_8B_RES`:

```python
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModel, AutoTokenizer
import requests
from io import BytesIO

# --- Helper functions for image preprocessing (from the original GitHub repo) ---
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    # Pick the tiling grid whose aspect ratio is closest to the input image's.
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio

def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=True):
    # Split the image into image_size x image_size tiles that best preserve its aspect ratio.
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images

def load_image(image_file_or_url, input_size=448, max_num=6):
    if isinstance(image_file_or_url, str) and image_file_or_url.startswith("http"):
        response = requests.get(image_file_or_url, stream=True)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file_or_url).convert('RGB')

    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values
# --- End of helper functions ---

# Load model and tokenizer
model_id = "jcwang0602/MLLMSeg_InternVL2_5_8B_RES"
model = AutoModel.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True
).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, use_fast=False)

# Example image and question
# Using an example image from the MLLMSeg repository for demonstration
image_url = "https://huggingface.co/jcwang0602/MLLMSeg_InternVL2_5_8B_RES/resolve/main/assets/res_example.png"
question = "Please give me the segmentation mask of the dog (with [SEG])."

# Preprocess image
pixel_values = load_image(image_url, max_num=6).to(torch.bfloat16).cuda()
generation_config = dict(max_new_tokens=1024, do_sample=True)

# Generate response
response, history = model.chat(tokenizer, pixel_values, question, generation_config, history=None, return_history=True)
print(f'User: {question}\nAssistant: {response}')

# The generated `response` contains the segmentation output (e.g., coordinates or a [SEG]-token-based result).
# Parse this string to extract the mask or coordinates for visualization.
```
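
Once a binary mask has been recovered from the model output, overlaying it on the source image is straightforward. The snippet below is only an illustrative sketch, assuming you already have the mask as a NumPy array of shape `(H, W)`; the `overlay_mask` helper is ours and not part of the MLLMSeg API:

```python
import numpy as np
from PIL import Image

def overlay_mask(image: Image.Image, mask: np.ndarray, color=(255, 0, 0), alpha=0.5) -> Image.Image:
    """Blend a binary (H, W) mask onto an RGB image for quick visual inspection."""
    image = image.convert("RGB")
    mask = mask.astype(bool)
    overlay = np.array(image).copy()
    overlay[mask] = (
        (1 - alpha) * overlay[mask] + alpha * np.array(color)
    ).astype(np.uint8)
    return Image.fromarray(overlay)

# Hypothetical usage: visualize a mask on the demo image
# img = Image.open("res_example.png")
# vis = overlay_mask(img, predicted_mask)   # predicted_mask parsed from `response`
# vis.save("overlay.png")
```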

## Checkpoints

Our checkpoints are available at:

| Base Model | RES Model | GRES Model |
|----------------|--------------------------------------------------------------|--------------------------------------------------------------|
| InternVL2_5_1B | [MLLMSeg_InternVL2_5_1B_RES](https://huggingface.co/jcwang0602/MLLMSeg_InternVL2_5_1B_RES) | - |
| InternVL2_5_2B | [MLLMSeg_InternVL2_5_2B_RES](https://huggingface.co/jcwang0602/MLLMSeg_InternVL2_5_2B_RES) | - |
| InternVL2_5_4B | [MLLMSeg_InternVL2_5_4B_RES](https://huggingface.co/jcwang0602/MLLMSeg_InternVL2_5_4B_RES) | - |
| InternVL2_5_8B | [MLLMSeg_InternVL2_5_8B_RES](https://huggingface.co/jcwang0602/MLLMSeg_InternVL2_5_8B_RES) | [MLLMSeg_InternVL2_5_8B_GRES](https://huggingface.co/jcwang0602/MLLMSeg_InternVL2_5_8B_GRES) |

## Performance Metrics

### Referring Expression Segmentation
<img src="https://github.com/jcwang0602/MLLMSeg/raw/main/assets/tab_res.png" width="800">

### Referring Expression Comprehension
<img src="https://github.com/jcwang0602/MLLMSeg/raw/main/assets/tab_rec.png" width="800">

### Generalized Referring Expression Segmentation
<img src="https://github.com/jcwang0602/MLLMSeg/raw/main/assets/tab_gres.png" width="800">

## Visualization

### Referring Expression Segmentation
<img src="https://github.com/jcwang0602/MLLMSeg/raw/main/assets/res.png" width="800">

### Referring Expression Comprehension
<img src="https://github.com/jcwang0602/MLLMSeg/raw/main/assets/rec.png" width="800">

### Generalized Referring Expression Segmentation
<img src="https://github.com/jcwang0602/MLLMSeg/raw/main/assets/gres.png" width="800">

## Citation

If our work is useful for your research, please consider citing:

```bibtex
@misc{wang2025unlockingpotentialmllmsreferring,
      title={Unlocking the Potential of MLLMs in Referring Expression Segmentation via a Light-weight Mask Decoder},
      author={Jingchao Wang and Zhijian Wu and Dingjiang Huang and Yefeng Zheng and Hong Wang},
      year={2025},
      eprint={2508.04107},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2508.04107},
}
```

## Acknowledgments

This code is developed on top of [InternVL](https://github.com/OpenGVLab/InternVL), [GSVA](https://github.com/LeapLabTHU/GSVA), and [EEVG](https://github.com/chenwei746/EEVG).