---
library_name: transformers
license: apache-2.0
tags: []
pipeline_tag: audio-text-to-text
---

# R1-AQA --- Reinforcement Learning Outperforms Supervised Fine-Tuning: A Case Study on Audio Question Answering


## Introduction

R1-AQA is an audio question answering (AQA) model based on `Qwen2-Audio-7B-Instruct`, optimized through reinforcement learning with the Group Relative Policy Optimization (GRPO) algorithm.
It achieves state-of-the-art performance on the MMAU *Test-mini* benchmark using only 38k post-training samples.
For more details, please refer to our [GitHub repository](https://github.com/xiaomi/r1-aqa) and [Technical Report](https://arxiv.org/abs/2503.11197).
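
To give a rough intuition for GRPO: instead of training a separate value model as in PPO, it samples a group of answers for each question and scores every answer relative to the statistics of its own group. The sketch below shows only that group-relative advantage normalization; the function name, the binary reward (1.0 for a correct `<answer>` tag, 0.0 otherwise), and the group size are illustrative assumptions, not the released training code.

```python
import torch

def group_relative_advantages(rewards: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """rewards: (num_questions, group_size) scalar rewards, one per sampled answer."""
    mean = rewards.mean(dim=1, keepdim=True)  # per-question group mean
    std = rewards.std(dim=1, keepdim=True)    # per-question group std
    return (rewards - mean) / (std + eps)     # normalized advantage for each sampled answer

# Example: 2 questions, 4 sampled answers each (hypothetical rewards)
rewards = torch.tensor([[1.0, 0.0, 0.0, 1.0],
                        [0.0, 0.0, 1.0, 0.0]])
print(group_relative_advantages(rewards))
```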

### Table: Accuracies (%) on MMAU Test-mini benchmark
| Model                                      | Method                  | Sound  | Music  | Speech | Average |
|--------------------------------------------|-------------------------|--------|--------|--------|---------|
| \                                          | Human\*                 | 86.31  | 78.22  | 82.17  | 82.23   |
| Gemini Pro 2.0 Flash                       | Direct Inference\*      | 56.46  | 58.68  | 51.65  | 55.60   |
| Audio Flamingo 2                           | Direct Inference\*      | 61.56  | **73.95** | 30.93  | 55.48   |
| GPT4o + Strong Cap.                        | Direct Inference\*      | 57.35  | 49.70  | **64.86** | 57.30   |
| Llama-3-8B-Instruct + Strong Cap.          | Direct Inference\*      | 50.75  | 48.93  | 55.25  | 52.10   |
| Gemini Pro v1.5                            | Direct Inference\*      | 56.75  | 49.40  | 58.55  | 54.90   |
| Qwen2-Audio-7B-Instruct                    | Direct Inference\*      | 54.95  | 50.98  | 42.04  | 49.20   |
| GPT4o + Weak Cap.                          | Direct Inference\*      | 39.33  | 41.90  | 58.25  | 45.70   |
| Llama-3-8B-Instruct + Weak Cap.            | Direct Inference\*      | 34.23  | 38.02  | 54.05  | 42.10   |
| SALMONN                                    | Direct Inference\*      | 41.00  | 34.80  | 25.50  | 33.70   |
| Qwen2-Audio-7B-Instruct                    | CoTA \[1\]            | 60.06  | 64.30  | 60.70  | 61.71   |
| Qwen2-Audio-7B-Instruct                    | Zero-Shot-CoT \[2\]   | 61.86  | 56.29  | 55.26  | 57.80   |
| **Qwen2-Audio-7B-Instruct**                | **GRPO (Ours)**         | **69.37** | 66.77  | 57.36  | **64.50** |

#### Notes:
\* The data are sourced from the MMAU official website: [https://sakshi113.github.io/mmau_homepage/](https://sakshi113.github.io/mmau_homepage/)  
\[1\] Xie, Zhifei, et al. "Audio-Reasoner: Improving Reasoning Capability in Large Audio Language Models." arXiv preprint arXiv:2503.02318 (2025).  
\[2\] Ma, Ziyang, et al. "Audio-CoT: Exploring Chain-of-Thought Reasoning in Large Audio Language Model." arXiv preprint arXiv:2501.07246 (2025).  



## Inference
```python
import torch
import torchaudio
from transformers import Qwen2AudioForConditionalGeneration, AutoProcessor

# Load model
model_name = "mispeech/r1-aqa"
processor = AutoProcessor.from_pretrained(model_name)
model = Qwen2AudioForConditionalGeneration.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")

# Load example audio (Qwen2-Audio expects 16 kHz input)
wav_path = "test-mini-audios/3fe64f3d-282c-4bc8-a753-68f8f6c35652.wav"  # from the MMAU dataset
waveform, sampling_rate = torchaudio.load(wav_path)
assert sampling_rate == 16000
audios = [waveform[0].numpy()]  # the processor expects 1-D (mono) arrays

# Make prompt text
question = "Based on the given audio, identify the source of the speaking voice."
options = ["Man", "Woman", "Child", "Robot"]
prompt = f"{question} Please choose the answer from the following options: {options}. Output the final answer in <answer> </answer>."
message = [
    {"role": "user", "content": [
        {"type": "audio", "audio_url": wav_path},
        {"type": "text", "text": prompt}
    ]}
]
texts = processor.apply_chat_template(message, add_generation_prompt=True, tokenize=False)

# Process
inputs = processor(text=texts, audios=audios, sampling_rate=16000, return_tensors="pt", padding=True).to(model.device)
generated_ids = model.generate(**inputs, max_new_tokens=256)
generated_ids = generated_ids[:, inputs.input_ids.size(1):]
response = processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)

print(response)
```
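
Since the prompt asks the model to wrap its final choice in `<answer> </answer>` tags, the reply can be parsed with a simple regular expression. The helper below continues from the snippet above and is an illustrative sketch; `extract_answer` and its fallback to the raw text are assumptions, not part of the released code.

```python
import re

def extract_answer(text: str) -> str:
    """Return the content of the last <answer>...</answer> span, or the raw text if no tag is found."""
    matches = re.findall(r"<answer>(.*?)</answer>", text, flags=re.DOTALL)
    return matches[-1].strip() if matches else text.strip()

print(extract_answer(response[0]))
```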