Quantile Regression for Distributional Reward Models in RLHF

(This is an older version. A newer model, trained on the decontaminated version of the Skywork dataset, is available as nicolinho/QRM-Llama3.1-8B-v2.)

Demo Code

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
device = "cuda"
path = "nicolinho/QRM-Llama3-8B"
model = AutoModelForSequenceClassification.from_pretrained(path, device_map=device, 
                               trust_remote_code=True, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(path, use_fast=True)
# We load a random sample from the validation set of the HelpSteer dataset
prompt = 'Does pineapple belong on a Pizza?'
response = "There are different opinions on this. Some people like pineapple on a Pizza while others condemn this."
messages = [{"role": "user", "content": prompt},
           {"role": "assistant", "content": response}]
input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt").to(device)
with torch.no_grad():
    output = model(input_ids)
    # Expectation of the reward distribution
    reward = output.score.cpu().float()
    # Quantile estimates for the quantiles 0.05, 0.1, ..., 0.9, 0.95 representing the distribution over rewards
    reward_quantiles = output.reward_quantiles.cpu().float()

# The attributes of the 19 reward objectives
attributes = ['helpsteer-helpfulness','helpsteer-correctness','helpsteer-coherence',
   'helpsteer-complexity','helpsteer-verbosity','ultrafeedback-overall_score',
   'ultrafeedback-instruction_following', 'ultrafeedback-truthfulness',
   'ultrafeedback-honesty','ultrafeedback-helpfulness','beavertails-is_safe',
   'prometheus-score','argilla-overall_quality','argilla-judge_lm','code-complexity',
   'code-style','code-explanation','code-instruction-following','code-readability']
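
The quantile estimates returned by the demo can be reduced to a few summary statistics. The following is a minimal sketch rather than part of the model's API: `summarize_quantiles` is a hypothetical helper, and it assumes the input is a plain list of 19 values ordered by quantile level (0.05, 0.1, ..., 0.95), as described in the demo comment. The example values are made up for illustration.

```python
# Hypothetical helper (not part of the model's API): summarize one vector
# of quantile estimates. Assumes 19 values ordered by quantile level
# 0.05, 0.10, ..., 0.95, as described in the demo comment.

NUM_QUANTILES = 19


def summarize_quantiles(quantiles):
    """Return simple summary statistics for one reward distribution."""
    if len(quantiles) != NUM_QUANTILES:
        raise ValueError(f"expected {NUM_QUANTILES} quantiles, got {len(quantiles)}")
    lower = quantiles[0]    # estimate for the 0.05 quantile
    median = quantiles[9]   # estimate for the 0.50 quantile
    upper = quantiles[18]   # estimate for the 0.95 quantile
    return {
        "median": median,
        "lower_tail": lower,            # pessimistic (risk-averse) reward estimate
        "spread_90": upper - lower,     # width of the central 90% interval
        # Crude mean estimate: average of equally spaced quantile estimates.
        "mean_of_quantiles": sum(quantiles) / len(quantiles),
    }


# Made-up quantile values for illustration only.
example = [float(i) for i in range(1, 20)]
summary = summarize_quantiles(example)
print(summary)
```

With the demo above, and assuming `reward_quantiles` has shape (batch, 19), the call might look like `summarize_quantiles(reward_quantiles[0].tolist())`. A wide `spread_90` suggests the model is uncertain about the reward, and `lower_tail` could serve as a conservative score. For the expected reward itself, `output.score` from the demo remains the value to use; `mean_of_quantiles` is only a rough approximation.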

Citation

If you find this work useful for your research, please consider citing:

@article{dorka2024quantile,
  title={Quantile Regression for Distributional Reward Models in RLHF},
  author={Dorka, Nicolai},
  journal={arXiv preprint arXiv:2409.10164},
  year={2024}
}