metadata
license: apache-2.0
library_name: transformers

RadVLM Model Card

A Multitask Conversational Vision-Language Model for Radiology (paper: https://arxiv.org/abs/2502.03333)

Model Development

  • Developed by: KrauthammerLab, University of Zurich, ETH Zurich, Kyoto University of Applied Science, Kobe University, Swiss AI Initiative
  • Contributors: Nicolas Deperrois, Hidetoshi Matsuo, Samuel Ruipérez-Campillo, Moritz Vandenhirtz, Sonia Laguna, Alain Ryser, Koji Fujimoto, Mizuho Nishio, Thomas M. Sutter, Julia E. Vogt, Jonas Kluckert, Thomas Frauenfelder, Christian Blüthgen, Farhad Nooralahzadeh, Michael Krauthammer

Model Overview

RadVLM is a compact, multitask vision-language model designed for conversational Chest X-ray (CXR) interpretation. Unlike traditional models focused solely on report generation, RadVLM supports interactive, multi-turn diagnostic conversations. It has been fine-tuned on a large-scale instruction dataset containing over 1 million image-instruction pairs, covering tasks such as abnormality classification, visual grounding, and structured conversations.

Intended Use

  • Primary Use Cases
    • Diagnostic Assistance: Providing conversational interpretations of CXRs to assist clinicians in reviewing images.
    • Medical Education: Supporting radiology trainees in learning CXR interpretation through interactive Q&A.
    • Preliminary Findings: Generating structured observations from CXRs to complement radiology reports.
  • Out-of-Scope Uses
    • Clinical Decision Making: RadVLM is not a replacement for a licensed radiologist and should not be used as the sole basis for medical decisions.
    • Automated Diagnosis: The model does not provide definitive diagnoses and should be used as a supplementary tool.
    • Use Outside of CXR Interpretation: The model has been trained specifically for Chest X-rays and is not designed for other medical imaging modalities.

Inputs and Outputs

  • Input:
    • Image: A frontal Chest X-ray (PIL Image or NumPy array).
    • Text: A user prompt (free-text query about the image).
    • Chat History (optional): Multi-turn interaction history.
  • Output:
    • Text Response: A natural language answer to the user's query.
    • Bounding Boxes (if applicable): Coordinates indicating the location of anatomical structures or abnormalities.
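The card does not specify how bounding boxes are written in the text response. The sketch below assumes each box appears as a bracketed list of four numbers `[x1, y1, x2, y2]` normalized to the 0-1 range. Verify this against real RadVLM outputs before relying on it. The helper names `parse_boxes` and `to_pixels` are hypothetical.

```python
import re
from typing import List, Tuple

Box = Tuple[float, float, float, float]

# Assumed format: a bracketed list of exactly four numbers, e.g. "[0.12, 0.30, 0.45, 0.70]"
_NUM = r"\s*(-?\d+(?:\.\d+)?)\s*"
_BOX_PATTERN = re.compile(r"\[" + ",".join([_NUM] * 4) + r"\]")


def parse_boxes(response: str) -> List[Box]:
    """Extract every four-number bracketed list from a model response (assumed box format)."""
    return [tuple(float(v) for v in match.groups()) for match in _BOX_PATTERN.finditer(response)]


def to_pixels(box: Box, width: int, height: int) -> Tuple[int, int, int, int]:
    """Scale an assumed normalized (0-1) box to pixel coordinates of a width x height image."""
    x1, y1, x2, y2 = box
    return (round(x1 * width), round(y1 * height), round(x2 * width), round(y2 * height))
```

With a PIL input, `width, height = image.size` gives the dimensions for `to_pixels`.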

Model Architecture

  • Backbone: LLaVA-OneVision-7B (https://huggingface.co/llava-hf/llava-onevision-qwen2-7b-si-hf), a general-purpose vision-language model that RadVLM adapts to CXR tasks.
  • Vision Encoder: SigLIP, used for image feature extraction.
  • Instruction Tuning: Fine-tuned with multi-task objectives, covering report generation, abnormality detection, and multi-turn Q&A.

Training Data

RadVLM was trained on a large-scale instruction dataset derived from publicly available medical sources:

  • MIMIC-CXR: Radiology reports paired with images.
  • CheXpert: Abnormality classification labels.
  • VinDr-CXR: Manually annotated abnormality locations.
  • Chest Imagenome: Bounding boxes for anatomical regions.
  • MS-CXR & PadChest-GR: Phrase grounding data.

All data sources were de-identified and anonymized prior to use.

Dependencies

pip install torch torchvision
pip install transformers==4.46.0
pip install accelerate pillow requests
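Because the install commands pin transformers==4.46.0, a quick environment check before loading the model can surface a version mismatch early. This is a minimal sketch using only the standard library's `importlib.metadata`. The helper name `check_environment` and the packages it checks are illustrative choices, not part of RadVLM.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, List, Optional


def check_environment(
    pins: Optional[Dict[str, str]] = None,
    required: Iterable[str] = ("torch", "torchvision"),
) -> List[str]:
    """Return human-readable problems; an empty list means every check passed."""
    if pins is None:
        pins = {"transformers": "4.46.0"}  # exact version pinned in the install commands
    problems = []
    # Pinned packages must be installed at exactly the expected version
    for name, expected in pins.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            problems.append(f"{name} is not installed (expected {expected})")
            continue
        if installed != expected:
            problems.append(f"{name} {installed} is installed, but {expected} is pinned")
    # Other packages only need to be present
    for name in required:
        try:
            version(name)
        except PackageNotFoundError:
            problems.append(f"{name} is not installed")
    return problems


if __name__ == "__main__":
    for problem in check_environment():
        print("WARNING:", problem)
```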

Inference function

The inference_radvlm function below runs one conversational turn and handles both single-turn and multi-turn use. It rebuilds the conversation from chat_history, inserts the image placeholder only in the first user message, and returns the updated history so that later calls keep the context of earlier exchanges.

import re

import torch
from PIL import Image
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration

def inference_radvlm(model, processor, image, prompt, chat_history=None, max_new_tokens=1500):
    """
    Generate a response using RadVLM in either single-turn or multi-turn mode.

    Args:
        model: The RadVLM model.
        processor: The processor for RadVLM (provides apply_chat_template and tokenization).
        image: A PIL Image or NumPy array representing the input image.
        prompt: The user prompt for this turn.
        chat_history: A list of (user_msg, assistant_msg) tuples representing the conversation so far.
                      If None or empty, single-turn mode is used. Even in single-turn mode, 
                      this function returns chat_history so that you can continue in subsequent turns.
        max_new_tokens: The maximum number of new tokens to generate.

    Returns:
        response (str): The assistant's response for this turn.
        chat_history (list): The updated chat_history including this turn's (prompt, response).
    """

    # Initialize chat history if not provided
    if chat_history is None:
        chat_history = []

    # Rebuild earlier turns as chat messages; only the first user turn carries the image placeholder
    conversation = []
    for idx, (user_text, assistant_text) in enumerate(chat_history):
        if idx == 0:
            conversation.append({
                "role": "user",
                "content": [
                    {"type": "text", "text": user_text},
                    {"type": "image"},
                ],
            })
        else:
            conversation.append({
                "role": "user",
                "content": [
                    {"type": "text", "text": user_text},
                ],
            })
        conversation.append({
            "role": "assistant",
            "content": [
                {"type": "text", "text": assistant_text},
            ],
        })

    # Add the current user prompt
    if len(chat_history) == 0:
        # First turn includes the image
        conversation.append({
            "role": "user",
            "content": [
                {"type": "text", "text": prompt},
                {"type": "image"},
            ],
        })
    else:
        # Subsequent turns without the image
        conversation.append({
            "role": "user",
            "content": [{"type": "text", "text": prompt}],
        })

    # Apply the chat template to create the full prompt
    full_prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)

    # Prepare model inputs
    inputs = processor(images=image, text=full_prompt, return_tensors="pt", padding=True).to(
        model.device, torch.float16
    )

    # Generate the response
    with torch.inference_mode():
        output = model.generate(**inputs, max_new_tokens=max_new_tokens, do_sample=False)

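    # Decode only the newly generated tokens; generate() returns the prompt tokens followed by the answer.
    # (Splitting the decoded text on "user"/"assistant" would truncate answers that contain those words.)
    prompt_length = inputs["input_ids"].shape[1]
    response = processor.decode(output[0][prompt_length:], skip_special_tokens=True).strip()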

    # Update chat history
    chat_history.append((prompt, response))

    return response, chat_history
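The message-building logic inside inference_radvlm can be factored out into a pure-Python helper, which makes it possible to check the conversation structure without loading the model. The sketch below mirrors the loop above. The name `build_conversation` is hypothetical; it shows that the image placeholder appears exactly once, in the first user turn, however long the history grows.

```python
from typing import Dict, List, Optional, Tuple


def build_conversation(
    prompt: str, chat_history: Optional[List[Tuple[str, str]]] = None
) -> List[Dict]:
    """Build the message list that inference_radvlm passes to apply_chat_template."""
    history = chat_history or []
    conversation = []
    for idx, (user_text, assistant_text) in enumerate(history):
        user_content = [{"type": "text", "text": user_text}]
        if idx == 0:  # the image is attached to the very first user turn only
            user_content.append({"type": "image"})
        conversation.append({"role": "user", "content": user_content})
        conversation.append(
            {"role": "assistant", "content": [{"type": "text", "text": assistant_text}]}
        )
    # Current turn: carries the image only if there is no earlier history
    current = [{"type": "text", "text": prompt}]
    if not history:
        current.append({"type": "image"})
    conversation.append({"role": "user", "content": current})
    return conversation
```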

Quick-Start: Multi-turn Demo

The following demo calls inference_radvlm three times on the same CXR, passing the returned chat_history back in on each call so that the model sees the earlier turns.

import torch
from transformers import AutoProcessor, LlavaOnevisionForConditionalGeneration
from PIL import Image
import requests

# Initialize the model and processor
model_id = "KrauthammerLab/RadVLM"
# float16 inference is intended for GPU; on CPU, generation is very slow and some ops may be unsupported
device = "cuda" if torch.cuda.is_available() else "cpu"
model = LlavaOnevisionForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
).to(device)

processor = AutoProcessor.from_pretrained(model_id)

image_url = "https://prod-images-static.radiopaedia.org/images/29923576/fed73420497c8622734f21ce20fc91_gallery.jpeg"
image = Image.open(requests.get(image_url, stream=True).raw).convert("RGB")

# Initialize chat history
chat_history = []

# First user prompt with image from URL
user_prompt_1 = "What can you say about this X-ray?"
response_1, chat_history = inference_radvlm(model, processor, image, user_prompt_1, chat_history)

print("RadVLM:", response_1)

# Second user prompt, continuing the conversation
user_prompt_2 = "Is there something concerning in the lungs area?"
response_2, chat_history = inference_radvlm(model, processor, image, user_prompt_2, chat_history)

print("RadVLM:", response_2)

# Third user prompt
user_prompt_3 = "What about the cardiac silhouette? Is it normal?"
response_3, chat_history = inference_radvlm(model, processor, image, user_prompt_3, chat_history)

print("RadVLM:", response_3)
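The demo downloads a JPEG from a URL. To use a local file or a NumPy array instead (for example, pixel data already extracted from another format), a small helper can normalize the input to an 8-bit RGB PIL Image before it is passed as `image`. This is a sketch. The name `to_rgb_image` is hypothetical, and the min-max rescaling of non-8-bit arrays is an assumption for illustration, not a documented RadVLM preprocessing step.

```python
import numpy as np
from PIL import Image


def to_rgb_image(image):
    """Convert a PIL Image or a NumPy array into an 8-bit RGB PIL Image.

    Assumption (illustrative only): non-uint8 arrays, such as 16-bit pixel data,
    are min-max rescaled to 0-255.
    """
    if isinstance(image, Image.Image):
        return image.convert("RGB")

    arr = np.asarray(image)
    if arr.dtype != np.uint8:
        arr = arr.astype(np.float32)
        lo, hi = float(arr.min()), float(arr.max())
        # Rescale to 0-255; a constant image maps to all zeros
        arr = (arr - lo) / (hi - lo) * 255.0 if hi > lo else np.zeros_like(arr)
        arr = np.round(arr).astype(np.uint8)

    if arr.ndim == 2:  # single-channel (grayscale) radiograph
        return Image.fromarray(arr).convert("RGB")
    if arr.ndim == 3 and arr.shape[2] == 3:  # already H x W x 3
        return Image.fromarray(arr)
    raise ValueError(f"Expected an H x W or H x W x 3 array, got shape {arr.shape}")
```

For example, `image = to_rgb_image(Image.open("chest_xray.png"))`, where `chest_xray.png` is a hypothetical local file, can replace the URL download above.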

References

If you use RadVLM, please cite:

@misc{deperrois2025radvlmmultitaskconversationalvisionlanguage,
      title={RadVLM: A Multitask Conversational Vision-Language Model for Radiology}, 
      author={Nicolas Deperrois and Hidetoshi Matsuo and Samuel Ruipérez-Campillo and Moritz Vandenhirtz and Sonia Laguna and Alain Ryser and Koji Fujimoto and Mizuho Nishio and Thomas M. Sutter and Julia E. Vogt and Jonas Kluckert and Thomas Frauenfelder and Christian Blüthgen and Farhad Nooralahzadeh and Michael Krauthammer},
      year={2025},
      eprint={2502.03333},
      archivePrefix={arXiv},
      primaryClass={cs.CV},
      url={https://arxiv.org/abs/2502.03333}, 
}