---
license: mit
pipeline_tag: image-segmentation
library_name: transformers
---
# MLLMSeg: Unlocking the Potential of MLLMs in Referring Expression Segmentation via a Light-weight Mask Decoder
This repository contains the **MLLMSeg** model, a novel framework for Referring Expression Segmentation (RES) and Generalized Referring Expression Segmentation (GRES), presented in the paper [Unlocking the Potential of MLLMs in Referring Expression Segmentation via a Light-weight Mask Decoder](https://arxiv.org/abs/2508.04107).
**Abstract:** Referring Expression Segmentation (RES) aims to segment image regions specified by referring expressions and has become popular with the rise of multimodal large language models (MLLMs). While MLLMs excel at semantic understanding, their token-generation paradigm struggles with pixel-level dense prediction. Existing RES methods either couple MLLMs with the parameter-heavy Segment Anything Model (SAM), whose network has 632M parameters, or adopt SAM-free lightweight pipelines that sacrifice accuracy. To address the trade-off between performance and cost, we propose MLLMSeg, a novel framework that fully exploits the inherent visual detail features encoded in the MLLM vision encoder without introducing an extra visual encoder. Besides, we propose a detail-enhanced and semantic-consistent feature fusion module (DSFF) that fully integrates the detail-related visual feature with the semantic-related feature output by the large language model (LLM) of the MLLM. Finally, we establish a light-weight mask decoder with only 34M network parameters that optimally leverages detailed spatial features from the visual encoder and semantic features from the LLM to achieve precise mask prediction. Extensive experiments demonstrate that our method generally surpasses both SAM-based and SAM-free competitors, striking a better balance between performance and cost.
## Paper and Code
- Paper: [Unlocking the Potential of MLLMs in Referring Expression Segmentation via a Light-weight Mask Decoder](https://arxiv.org/abs/2508.04107)
- GitHub Repository: https://github.com/jcwang0602/MLLMSeg
## Usage
You can use the MLLMSeg model with the `transformers` library. The model takes an image and a referring expression as input and outputs a segmentation mask or coordinates. The models accept images of any size. Coordinate outputs are normalized to the 0-1000 range relative to the input image (either a center point or a bounding box given by its top-left and bottom-right corners), so for visualization you need to rescale these relative coordinates back to the original image dimensions, as sketched below.
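As a minimal, hedged example of that rescaling step (the box values below are hypothetical; the exact output format depends on the checkpoint and prompt):

```python
def rescale_box(box_0_1000, image_width, image_height):
    """Convert a box in 0-1000 relative coordinates to pixel coordinates."""
    x1, y1, x2, y2 = box_0_1000
    return (
        x1 / 1000 * image_width,
        y1 / 1000 * image_height,
        x2 / 1000 * image_width,
        y2 / 1000 * image_height,
    )

# Hypothetical model output for a 1280x960 image.
print(rescale_box([250, 100, 750, 900], 1280, 960))  # -> (320.0, 96.0, 960.0, 864.0)
```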
### Installation
First, install the necessary dependencies:
```bash
conda create -n mllmseg python==3.10.18 -y
conda activate mllmseg
pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu118
pip install -r requirements.txt
pip install flash-attn==2.3.6 --no-build-isolation  # Note: a GPU is required to build flash-attn
```
### Sample Usage
Here's a basic example demonstrating how to load and use the model for inference. We'll use the `MLLMSeg_InternVL2_5_8B_RES` model as an example.
```python
import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModel, AutoTokenizer
import requests
from io import BytesIO

# --- Helper functions for image preprocessing (from the original GitHub repo) ---
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio

def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=True):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    # Enumerate candidate tile grids (i x j) with min_num <= i*j <= max_num.
    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    # Pick the grid whose aspect ratio best matches the input image.
    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    # Resize the image, then split it into image_size x image_size tiles.
    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images

def load_image(image_file_or_url, input_size=448, max_num=6):
    if isinstance(image_file_or_url, str) and image_file_or_url.startswith("http"):
        response = requests.get(image_file_or_url, stream=True)
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file_or_url).convert('RGB')
    transform = build_transform(input_size=input_size)
    images = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = [transform(image) for image in images]
    pixel_values = torch.stack(pixel_values)
    return pixel_values
# --- End of helper functions ---

# Load model and tokenizer
model_id = "jcwang0602/MLLMSeg_InternVL2_5_8B_RES"
model = AutoModel.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True
).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, use_fast=False)

# Example image and question
# Using an example image from the MLLMSeg repository for demonstration
image_url = "https://huggingface.co/jcwang0602/MLLMSeg_InternVL2_5_8B_RES/resolve/main/assets/res_example.png"
question = "Please give me the segmentation mask of the dog (with [SEG])."

# Preprocess image
pixel_values = load_image(image_url, max_num=6).to(torch.bfloat16).cuda()
generation_config = dict(max_new_tokens=1024, do_sample=True)

# Generate response
response, history = model.chat(tokenizer, pixel_values, question, generation_config, history=None, return_history=True)
print(f'User: {question}\nAssistant: {response}')

# The output `response` will contain the segmentation information (e.g., coordinates or a [SEG]-token-based output).
# You would then need to parse this string to extract the mask or coordinates for visualization, as sketched below.
```
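For coordinate-style outputs, a rough sketch of that post-processing step is shown below. The bracketed `[[x1, y1, x2, y2]]` pattern is an assumption about the response format, so verify it against the actual output of your checkpoint; mask outputs signalled by the `[SEG]` token are produced by the model's light-weight mask decoder rather than embedded in the response text.

```python
import re

def parse_boxes(response_text, image_width, image_height):
    """Extract 0-1000-normalized boxes like [[x1, y1, x2, y2]] from the response
    text and rescale them to pixel coordinates. The bracketed-box pattern is an
    assumption; check the real response before relying on it."""
    boxes = []
    for match in re.finditer(r"\[\[\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\]\]", response_text):
        x1, y1, x2, y2 = (int(v) for v in match.groups())
        boxes.append((
            x1 / 1000 * image_width,
            y1 / 1000 * image_height,
            x2 / 1000 * image_width,
            y2 / 1000 * image_height,
        ))
    return boxes

# Hypothetical response for a 1280x960 image.
print(parse_boxes("The dog is at [[120, 80, 640, 910]].", 1280, 960))
# -> [(153.6, 76.8, 819.2, 873.6)]
```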
## Checkpoints
Our released checkpoints are listed below:
| Base Model | RES Model | GRES Model |
|---|---|---|
| InternVL2_5_1B | MLLMSeg_InternVL2_5_1B_RES | - |
| InternVL2_5_2B | MLLMSeg_InternVL2_5_2B_RES | - |
| InternVL2_5_4B | MLLMSeg_InternVL2_5_4B_RES | - |
| InternVL2_5_8B | MLLMSeg_InternVL2_5_8B_RES | MLLMSeg_InternVL2_5_8B_GRES |
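To run a different checkpoint, point `model_id` at the corresponding repository. For example, assuming the GRES checkpoint is published under the same `jcwang0602` namespace as the RES model used above:

```python
# Assumed repository id; check the GitHub README for the exact released checkpoint names.
model_id = "jcwang0602/MLLMSeg_InternVL2_5_8B_GRES"
```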
## Performance Metrics

Benchmark results for Referring Expression Segmentation, Referring Expression Comprehension, and Generalized Referring Expression Segmentation are reported in the paper and the GitHub repository.

## Visualization

Qualitative results for Referring Expression Segmentation, Referring Expression Comprehension, and Generalized Referring Expression Segmentation are shown in the GitHub repository.
## Citation
If our work is useful for your research, please consider citing:
```bibtex
@misc{wang2025unlockingpotentialmllmsreferring,
      title={Unlocking the Potential of MLLMs in Referring Expression Segmentation via a Light-weight Mask Decoder},
      author={Jingchao Wang and Zhijian Wu and Dingjiang Huang and Yefeng Zheng and Hong Wang},
      year={2025},
      eprint={2508.04107},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2508.04107},
}
```
## Acknowledgments

This code is built on top of InternVL, GSVA, and EEVG.