|
--- |
|
library_name: transformers |
|
license: apache-2.0 |
|
tags: [] |
|
pipeline_tag: audio-text-to-text |
|
--- |
|
|
|
# R1-AQA --- Reinforcement Learning Outperforms Supervised Fine-Tuning: A Case Study on Audio Question Answering |
|
|
|
|
|
|
## Introduction |
|
|
|
R1-AQA is an audio question answering (AQA) model based on `Qwen2-Audio-7B-Instruct`, optimized with reinforcement learning using the Group Relative Policy Optimization (GRPO) algorithm.

It achieves state-of-the-art performance on the MMAU *Test-mini* benchmark with only 38k post-training samples.

For more details, please refer to our [GitHub repository](https://github.com/xiaomi-research/r1-aqa) and [Technical Report](https://arxiv.org/abs/2503.11197).
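
GRPO samples a group of candidate answers for each training prompt and scores each answer relative to the rest of its group instead of against a learned value baseline. As a rough illustration only (the helper name and reward values below are assumptions for demonstration, not the exact R1-AQA training recipe; see the Technical Report for details), the group-relative advantage computation looks like this:

```python
import numpy as np

def group_relative_advantages(rewards, eps=1e-6):
    """Normalize per-answer rewards within one group of sampled rollouts.

    GRPO uses the group mean/std as the baseline instead of a learned critic,
    so each sampled answer is scored relative to its siblings for the same prompt.
    """
    rewards = np.asarray(rewards, dtype=np.float32)
    return (rewards - rewards.mean()) / (rewards.std() + eps)

# Hypothetical rewards for a group of 4 sampled answers to one AQA prompt,
# e.g., 1.0 for a correct choice plus a small bonus for well-formed output.
print(group_relative_advantages([1.1, 0.1, 0.1, 1.1]))
```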
|
|
|
### Table: Accuracies (%) on the MMAU *Test-mini* benchmark
|
| Model | Method | Sound | Music | Speech | Average | |
|
|--------------------------------------------|-------------------------|--------|--------|--------|---------| |
|
| \ | Human\* | 86.31 | 78.22 | 82.17 | 82.23 | |
|
| Gemini Pro 2.0 Flash | Direct Inference\* | 56.46 | 58.68 | 51.65 | 55.60 | |
|
| Audio Flamingo 2 | Direct Inference\* | 61.56 | **73.95** | 30.93 | 55.48 | |
|
| GPT4o + Strong Cap. | Direct Inference\* | 57.35 | 49.70 | **64.86** | 57.30 | |
|
| Llama-3-8B-Instruct + Strong Cap. | Direct Inference\* | 50.75 | 48.93 | 55.25 | 52.10 | |
|
| Gemini Pro v1.5 | Direct Inference\* | 56.75 | 49.40 | 58.55 | 54.90 | |
|
| Qwen2-Audio-7B-Instruct | Direct Inference\* | 54.95 | 50.98 | 42.04 | 49.20 | |
|
| GPT4o + Weak Cap. | Direct Inference\* | 39.33 | 41.90 | 58.25 | 45.70 | |
|
| Llama-3-8B-Instruct + Weak Cap. | Direct Inference\* | 34.23 | 38.02 | 54.05 | 42.10 | |
|
| SALMONN | Direct Inference\* | 41.00 | 34.80 | 25.50 | 33.70 | |
|
| Qwen2-Audio-7B-Instruct | CoTA \[1\] | 60.06 | 64.30 | 60.70 | 61.71 | |
|
| Qwen2-Audio-7B-Instruct | Zero-Shot-CoT \[2\] | 61.86 | 56.29 | 55.26 | 57.80 | |
|
| **Qwen2-Audio-7B-Instruct** | **GRPO (Ours)** | **69.37** | 66.77 | 57.36 | **64.50** | |
|
|
|
#### Notes: |
|
\* The data are sourced from the MMAU official website: [https://sakshi113.github.io/mmau_homepage/](https://sakshi113.github.io/mmau_homepage/) |
|
\[1\] Xie, Zhifei, et al. "Audio-Reasoner: Improving Reasoning Capability in Large Audio Language Models." arXiv preprint arXiv:2503.02318 (2025). |
|
\[2\] Ma, Ziyang, et al. "Audio-CoT: Exploring Chain-of-Thought Reasoning in Large Audio Language Model." arXiv preprint arXiv:2501.07246 (2025). |
|
|
|
|
|
|
|
## Inference |
|
```python |
|
import torch |
|
import torchaudio |
|
from transformers import Qwen2AudioForConditionalGeneration, AutoProcessor |
|
|
|
# Load model |
|
model_name = "mispeech/r1-aqa" |
|
processor = AutoProcessor.from_pretrained(model_name) |
|
model = Qwen2AudioForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto") |
|
|
|
# Load example audio |
|
wav_path = "test-mini-audios/3fe64f3d-282c-4bc8-a753-68f8f6c35652.wav" # from MMAU dataset |
|
waveform, sampling_rate = torchaudio.load(wav_path) |
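# Qwen2-Audio's feature extractor expects 16 kHz audio, so resample if necessary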
|
if sampling_rate != 16000:

    waveform = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=16000)(waveform)

audios = [waveform[0].numpy()]
|
|
|
# Make prompt text |
|
question = "Based on the given audio, identify the source of the speaking voice." |
|
options = ["Man", "Woman", "Child", "Robot"] |
|
prompt = f"{question} Please choose the answer from the following options: {str(options)}. Output the final answer in <answer> </answer>." |
|
message = [ |
|
{"role": "user", "content": [ |
|
{"type": "audio", "audio_url": wav_path}, |
|
{"type": "text", "text": prompt} |
|
]} |
|
] |
|
texts = processor.apply_chat_template(message, add_generation_prompt=True, tokenize=False) |
|
|
|
# Process |
|
inputs = processor(text=texts, audios=audios, sampling_rate=16000, return_tensors="pt", padding=True).to(model.device) |
|
generated_ids = model.generate(**inputs, max_new_tokens=256) |
|
generated_ids = generated_ids[:, inputs.input_ids.size(1):]  # keep only the newly generated tokens
|
response = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False) |
|
|
|
print(response) |
|
``` |
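
Since the prompt asks the model to wrap its final choice in `<answer> </answer>` tags, the selected option can be recovered from the decoded text with a small post-processing step. The `extract_answer` helper below is a hypothetical convenience for illustration, not part of the official repository:

```python
import re

def extract_answer(text: str) -> str:
    """Return the content of the first <answer>...</answer> span, or the raw text as a fallback."""
    match = re.search(r"<answer>(.*?)</answer>", text, flags=re.DOTALL)
    return match.group(1).strip() if match else text.strip()

# Continues from the example above: `response` is the list returned by batch_decode.
print(extract_answer(response[0]))
```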