MLLMSeg: Unlocking the Potential of MLLMs in Referring Expression Segmentation via a Light-weight Mask Decoder

This repository contains the MLLMSeg model, a novel framework for Referring Expression Segmentation (RES) and Generalized Referring Expression Segmentation (GRES), presented in the paper Unlocking the Potential of MLLMs in Referring Expression Segmentation via a Light-weight Mask Decoder.

Abstract: Referring Expression Segmentation (RES) aims to segment image regions specified by referring expressions and has become popular with the rise of multimodal large language models (MLLMs). While MLLMs excel in semantic understanding, their token-generation paradigm struggles with pixel-level dense prediction. Existing RES methods either couple MLLMs with the parameter-heavy Segment Anything Model (SAM), which has 632M network parameters, or adopt SAM-free lightweight pipelines that sacrifice accuracy. To address this trade-off between performance and cost, we propose MLLMSeg, a novel framework that fully exploits the visual detail features already encoded in the MLLM vision encoder without introducing an extra visual encoder. Besides, we propose a detail-enhanced and semantic-consistent feature fusion module (DSFF) that fully integrates the detail-related visual feature with the semantic-related feature output by the large language model (LLM) of the MLLM. Finally, we establish a light-weight mask decoder with only 34M network parameters that leverages detailed spatial features from the visual encoder and semantic features from the LLM to achieve precise mask prediction. Extensive experiments demonstrate that our method generally surpasses both SAM-based and SAM-free competitors, striking a better balance between performance and cost.
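For scale, here is the arithmetic behind the cost claim, comparing only the two mask-prediction components quoted in the abstract (SAM at 632M parameters versus the 34M-parameter decoder). This deliberately ignores the shared MLLM, so it is not a whole-model comparison:

```latex
\frac{34\,\mathrm{M}}{632\,\mathrm{M}} \approx 0.054
\qquad\text{and}\qquad
\frac{632\,\mathrm{M}}{34\,\mathrm{M}} \approx 18.6
```

In other words, the decoder is roughly 5.4% of SAM's size, or about 18.6 times smaller.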

Paper and Code

Usage

You can use the MLLMSeg model with the transformers library. The model takes an image and a referring expression as input and outputs a segmentation mask or coordinates. Our models accept images of any size as input. Coordinate outputs are normalized to relative coordinates in the 0-1000 range (either a center point or a bounding box defined by its top-left and bottom-right corners). For visualization, remember to convert these relative coordinates back to the original image dimensions.
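The conversion back to pixels is a linear rescale: divide by 1000 and multiply by the image width (for x) or height (for y). The sketch below assumes boxes arrive as `(x1, y1, x2, y2)` and points as `(x, y)`, already parsed into numbers. The helper names `rel_to_abs_box` and `rel_to_abs_point` are hypothetical and not part of the MLLMSeg code.

```python
from typing import Sequence, Tuple

REL_SCALE = 1000.0  # the model's relative coordinate range is 0-1000


def _to_pixel(value: float, size: int, scale: float = REL_SCALE) -> int:
    """Rescale one relative coordinate to pixels, clamped to [0, size]."""
    return min(max(round(value / scale * size), 0), size)


def rel_to_abs_box(box: Sequence[float], width: int, height: int) -> Tuple[int, int, int, int]:
    """Map an (x1, y1, x2, y2) box from 0-1000 space to pixel coordinates."""
    if len(box) != 4:
        raise ValueError("expected a box as (x1, y1, x2, y2)")
    x1, y1, x2, y2 = box
    return (_to_pixel(x1, width), _to_pixel(y1, height),
            _to_pixel(x2, width), _to_pixel(y2, height))


def rel_to_abs_point(point: Sequence[float], width: int, height: int) -> Tuple[int, int]:
    """Map an (x, y) center point from 0-1000 space to pixel coordinates."""
    x, y = point
    return _to_pixel(x, width), _to_pixel(y, height)


if __name__ == "__main__":
    # A box covering the middle of a 640x480 image
    print(rel_to_abs_box((250, 100, 750, 900), width=640, height=480))  # (160, 48, 480, 432)
    print(rel_to_abs_point((500, 500), width=640, height=480))  # (320, 240)
```

Clamping guards against values that land slightly outside the image after rounding. Pass the original image size, not the 448×448 tile size, because the coordinates are relative to the whole input image.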

Installation

First, create an environment and install the necessary dependencies (requirements.txt ships with the MLLMSeg code repository):

conda create -n mllmseg python==3.10.18 -y
conda activate mllmseg
pip install torch==2.5.1 torchvision==0.20.1 --index-url https://download.pytorch.org/whl/cu118
pip install -r requirements.txt
pip install flash-attn==2.3.6 --no-build-isolation # Note: building flash-attn requires a machine with a CUDA GPU
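To confirm that the pinned versions actually ended up in the environment, you can run a check that uses only the standard library. This is a minimal sketch: the pins are copied from the commands above, `check_pins` is a hypothetical helper, and it assumes the installed distribution names match the pip names (`importlib.metadata` normalizes `flash-attn`/`flash_attn` on Python 3.10+).

```python
import sys
from importlib.metadata import PackageNotFoundError, version

# Version pins taken from the install commands above
EXPECTED = {"torch": "2.5.1", "torchvision": "0.20.1", "flash-attn": "2.3.6"}


def check_pins(expected=EXPECTED):
    """Return {package: problem} for every pin that is missing or mismatched."""
    problems = {}
    for name, want in expected.items():
        try:
            have = version(name)
        except PackageNotFoundError:
            problems[name] = "not installed"
            continue
        # Wheels from the PyTorch CUDA index carry a local suffix such as "+cu118"
        if have.split("+")[0] != want:
            problems[name] = f"found {have}, expected {want}"
    return problems


if __name__ == "__main__":
    print("Python", sys.version.split()[0])
    issues = check_pins()
    for pkg, msg in issues.items():
        print(f"{pkg}: {msg}")
    if not issues:
        print("All pinned packages match.")
```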

Sample Usage

Here's a basic example demonstrating how to load and use the model for inference. We'll use the MLLMSeg_InternVL2_5_8B_RES model as an example.
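The preprocessing helpers in this example handle arbitrary image sizes by choosing a grid of 448×448 tiles whose aspect ratio is closest to the input, then adding one thumbnail when more than one tile is used. The sketch below reproduces only that grid choice, without PIL or torch, so you can predict how many tiles (and therefore how much GPU memory) an image will need. `tile_grid` and `num_tiles` are hypothetical names; the selection rule mirrors `find_closest_aspect_ratio` and `dynamic_preprocess` as written in the example.

```python
def tile_grid(width, height, min_num=1, max_num=6, image_size=448):
    """Return (cols, rows) picked by the same aspect-ratio rule as dynamic_preprocess."""
    aspect_ratio = width / height
    # All grids with between min_num and max_num tiles, smallest grids first
    target_ratios = sorted(
        {(i, j) for n in range(min_num, max_num + 1)
         for i in range(1, n + 1) for j in range(1, n + 1)
         if min_num <= i * j <= max_num},
        key=lambda r: r[0] * r[1],
    )
    best_diff, best = float("inf"), (1, 1)
    area = width * height
    for cols, rows in target_ratios:
        diff = abs(aspect_ratio - cols / rows)
        if diff < best_diff:
            best_diff, best = diff, (cols, rows)
        # On a tie, prefer the larger grid only if the image has enough pixels for it
        elif diff == best_diff and area > 0.5 * image_size * image_size * cols * rows:
            best = (cols, rows)
    return best


def num_tiles(width, height, max_num=6, use_thumbnail=True):
    """Number of 448x448 crops fed to the model, including the optional thumbnail."""
    cols, rows = tile_grid(width, height, max_num=max_num)
    n = cols * rows
    return n + 1 if use_thumbnail and n != 1 else n


if __name__ == "__main__":
    print(tile_grid(800, 600), num_tiles(800, 600))  # (3, 2) 7
```

With `max_num=6`, an 800×600 photo becomes a 3×2 grid plus a thumbnail, so `pixel_values` has 7 entries along its first dimension.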

import torch
import torchvision.transforms as T
from PIL import Image
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModel, AutoTokenizer
import requests
from io import BytesIO

# --- Helper functions for image preprocessing (from original GitHub repo) ---
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def build_transform(input_size):
    MEAN, STD = IMAGENET_MEAN, IMAGENET_STD
    transform = T.Compose([
        T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
        T.Resize((input_size, input_size), interpolation=InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD)
    ])
    return transform

def find_closest_aspect_ratio(aspect_ratio, target_ratios, width, height, image_size):
    best_ratio_diff = float('inf')
    best_ratio = (1, 1)
    area = width * height
    for ratio in target_ratios:
        target_aspect_ratio = ratio[0] / ratio[1]
        ratio_diff = abs(aspect_ratio - target_aspect_ratio)
        if ratio_diff < best_ratio_diff:
            best_ratio_diff = ratio_diff
            best_ratio = ratio
        elif ratio_diff == best_ratio_diff:
            if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
                best_ratio = ratio
    return best_ratio

def dynamic_preprocess(image, min_num=1, max_num=6, image_size=448, use_thumbnail=True):
    orig_width, orig_height = image.size
    aspect_ratio = orig_width / orig_height

    target_ratios = set(
        (i, j) for n in range(min_num, max_num + 1) for i in range(1, n + 1) for j in range(1, n + 1) if
        i * j <= max_num and i * j >= min_num)
    target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])

    target_aspect_ratio = find_closest_aspect_ratio(
        aspect_ratio, target_ratios, orig_width, orig_height, image_size)

    target_width = image_size * target_aspect_ratio[0]
    target_height = image_size * target_aspect_ratio[1]
    blocks = target_aspect_ratio[0] * target_aspect_ratio[1]

    resized_img = image.resize((target_width, target_height))
    processed_images = []
    for i in range(blocks):
        box = (
            (i % (target_width // image_size)) * image_size,
            (i // (target_width // image_size)) * image_size,
            ((i % (target_width // image_size)) + 1) * image_size,
            ((i // (target_width // image_size)) + 1) * image_size
        )
        split_img = resized_img.crop(box)
        processed_images.append(split_img)
    assert len(processed_images) == blocks
    if use_thumbnail and len(processed_images) != 1:
        thumbnail_img = image.resize((image_size, image_size))
        processed_images.append(thumbnail_img)
    return processed_images

def load_image(image_file_or_url, input_size=448, max_num=6):
    # Accept either a local file path or an http(s) URL
    if isinstance(image_file_or_url, str) and image_file_or_url.startswith("http"):
        response = requests.get(image_file_or_url, timeout=30)
        response.raise_for_status()  # fail loudly on 404s instead of passing HTML to PIL
        image = Image.open(BytesIO(response.content)).convert('RGB')
    else:
        image = Image.open(image_file_or_url).convert('RGB')

    # Split into up to max_num tiles (plus a thumbnail), then normalize each tile
    transform = build_transform(input_size=input_size)
    tiles = dynamic_preprocess(image, image_size=input_size, use_thumbnail=True, max_num=max_num)
    pixel_values = torch.stack([transform(tile) for tile in tiles])
    return pixel_values
# --- End of helper functions ---

# Load model and tokenizer
model_id = "jcwang0602/MLLMSeg_InternVL2_5_8B_RES"
model = AutoModel.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    trust_remote_code=True
).eval().cuda()
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True, use_fast=False)

# Example image and question
# Using an example image from the MLLMSeg repository for demonstration
image_url = "https://huggingface.co/jcwang0602/MLLMSeg_InternVL2_5_8B_RES/resolve/main/assets/res_example.png"
question = "Please give me the segmentation mask of the dog (with [SEG])."

# Preprocess image
pixel_values = load_image(image_url, max_num=6).to(torch.bfloat16).cuda()
generation_config = dict(max_new_tokens=1024, do_sample=True)

# Generate response
response, history = model.chat(tokenizer, pixel_values, question, generation_config, history=None, return_history=True)
print(f'User: {question}\nAssistant: {response}')

# The output `response` will contain the segmentation information (e.g., coordinates or SEG token based output).
# You would then need to parse this string to extract the mask or coordinates for visualization.
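The exact response format is not documented here, so the parser below is a sketch under an explicit assumption: that boxes appear as bracketed groups of four integers in 0-1000 space, for example the InternVL-style `<box>[[x1, y1, x2, y2]]</box>`. `parse_boxes` and `boxes_to_pixels` are hypothetical helpers; adjust the regular expression if your responses look different. Masks produced through the `[SEG]` token are not handled by this sketch.

```python
import re
from typing import List, Tuple

Box = Tuple[int, int, int, int]

# Assumed format: any "[x1, y1, x2, y2]" group of four integers (matches inside "[[...]]" too)
_BOX_RE = re.compile(r"\[\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*,\s*(\d+)\s*\]")


def parse_boxes(response: str) -> List[Box]:
    """Extract every 4-integer box from a model response string."""
    return [tuple(int(v) for v in m.groups()) for m in _BOX_RE.finditer(response)]


def boxes_to_pixels(boxes: List[Box], width: int, height: int, scale: float = 1000.0) -> List[Box]:
    """Rescale boxes from 0-1000 relative space to pixel coordinates."""
    return [
        (round(x1 / scale * width), round(y1 / scale * height),
         round(x2 / scale * width), round(y2 / scale * height))
        for x1, y1, x2, y2 in boxes
    ]


if __name__ == "__main__":
    sample = "<ref>the dog</ref><box>[[100, 200, 500, 800]]</box>"  # made-up example string
    boxes = parse_boxes(sample)
    print(boxes)  # [(100, 200, 500, 800)]
    print(boxes_to_pixels(boxes, width=640, height=480))  # [(64, 96, 320, 384)]
```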

Checkpoints

Our checkpoints are available at:

| Base Model | RES Model | GRES Model |
|---|---|---|
| InternVL2_5_1B | MLLMSeg_InternVL2_5_1B_RES | - |
| InternVL2_5_2B | MLLMSeg_InternVL2_5_2B_RES | - |
| InternVL2_5_4B | MLLMSeg_InternVL2_5_4B_RES | - |
| InternVL2_5_8B | MLLMSeg_InternVL2_5_8B_RES | MLLMSeg_InternVL2_5_8B_GRES |
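If you switch between checkpoints programmatically, a small lookup built from the table above avoids requesting combinations that do not exist (GRES is only available for the 8B base model). This is a sketch: `checkpoint_id` is a hypothetical helper, and the `jcwang0602/` namespace is confirmed only for `MLLMSeg_InternVL2_5_8B_RES` (used in the sample code) and is assumed for the other checkpoints.

```python
# Availability copied from the checkpoint table; the hub namespace is an assumption
# for every model except MLLMSeg_InternVL2_5_8B_RES.
CHECKPOINTS = {
    ("1B", "RES"): "MLLMSeg_InternVL2_5_1B_RES",
    ("2B", "RES"): "MLLMSeg_InternVL2_5_2B_RES",
    ("4B", "RES"): "MLLMSeg_InternVL2_5_4B_RES",
    ("8B", "RES"): "MLLMSeg_InternVL2_5_8B_RES",
    ("8B", "GRES"): "MLLMSeg_InternVL2_5_8B_GRES",
}


def checkpoint_id(size: str, task: str = "RES", namespace: str = "jcwang0602") -> str:
    """Return a hub repo id such as 'jcwang0602/MLLMSeg_InternVL2_5_8B_RES'."""
    key = (size.upper(), task.upper())
    if key not in CHECKPOINTS:
        available = ", ".join(f"{s}/{t}" for s, t in CHECKPOINTS)
        raise KeyError(f"no {key[1]} checkpoint for {key[0]}; available: {available}")
    return f"{namespace}/{CHECKPOINTS[key]}"


if __name__ == "__main__":
    print(checkpoint_id("8B", "GRES"))  # jcwang0602/MLLMSeg_InternVL2_5_8B_GRES
```

The returned string can be passed as `model_id` to `AutoModel.from_pretrained` and `AutoTokenizer.from_pretrained` in the sample code.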

Performance Metrics

Results are reported for Referring Expression Segmentation, Referring Expression Comprehension, and Generalized Referring Expression Segmentation.

Visualization

Qualitative examples cover Referring Expression Segmentation, Referring Expression Comprehension, and Generalized Referring Expression Segmentation.

Citation

If our work is useful for your research, please consider citing:

@misc{wang2025unlockingpotentialmllmsreferring,
      title={Unlocking the Potential of MLLMs in Referring Expression Segmentation via a Light-weight Mask Decoder},
      author={Jingchao Wang and Zhijian Wu and Dingjiang Huang and Yefeng Zheng and Hong Wang},
      year={2025},
      eprint={2508.04107},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2508.04107},
}

Acknowledgments

This code is developed on top of InternVL, GSVA, and EEVG.

Model size: 8.11B params (Safetensors). Tensor types: BF16, F16.