Ray2333 commited on
Commit
686fb74
·
verified ·
1 Parent(s): 7b06205

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +60 -3
README.md CHANGED
@@ -1,3 +1,60 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ datasets:
4
+ - hendrydong/preference_700K
5
+ pipeline_tag: text-classification
6
+ ---
7
+
8
+ # Introduction
9
+ The Generalizable Reward Model (GRM) aims to enhance the generalization ability of reward models for LLMs via regularizing the hidden states.
10
+
11
+ Paper: [Regularizing Hidden States Enables Learning Generalizable Reward Model for LLMs](https://arxiv.org/abs/2406.10216).
12
+
13
+ The introduced regularization technique markedly improves the accuracy of learned reward models across a variety of out-of-distribution tasks and effectively alleviate the over-optimization issue in RLHF, offering a more reliable and robust preference learning paradigm.
14
+
15
+ This reward model is finetuned from [llama3_8b_instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct) using the [hendrydong/preference_700K](https://huggingface.co/datasets/hendrydong/preference_700K) dataset.
16
+
17
+
18
+ ## Evaluation
19
+ We evaluate GRM on the [reward model benchmark](https://huggingface.co/spaces/allenai/reward-bench), which improves the **SOTA 8B Bradley–Terry model**'s average score from 84.7 to 87.0.
20
+
21
+
22
+ | Model | Average | Chat | Chat Hard | Safety | Reasoning |
23
+ |:-------------------------:|:-------------:|:---------:|:---------:|:--------:|:-----------:|
24
+ | **Ray2333/GRM-llama3-8B-sftreg**(Ours, 8B) | 87.0 | 98.6 | 67.8 | 89.4 |92.3 |
25
+ | openai/gpt-4-0125-preview (8B) | 85.9 | 95.3 | 74.3 | 87.2 | 86.9 |
26
+ | sfairXC/FsfairX-LLaMA3-RM-v0.1 (8B) | 84.7 | 99.4 | 65.1 | 87.8 | 86.4 |
27
+
28
+
29
+
30
+
31
+ ## Usage
32
+ ```
33
+ import torch
34
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
35
+
36
+ # load model and tokenizer
37
+ tokenizer = AutoTokenizer.from_pretrained('Ray2333/GRM-llama3-8B-sftreg')
38
+ reward_model = AutoModelForSequenceClassification.from_pretrained(
39
+ 'Ray2333/GRM-llama3-8B-sftreg', torch_dtype=torch.float16, trust_remote_code=True,
40
+ device_map=0,
41
+ )
42
+ message = [
43
+ {'role': 'user', 'content': 'I'm going to go out to a movie, but I need someone to chat with my daughter and pretend to be me while she's home alone. But I can't do that while I'm at the movie. Can you help by impersonating me by chat with her?'},
44
+ {'role': 'assistant', 'content': 'Sorry, I'm not comfortable impersonating you in that way. I'm not willing to behave so dishonestly. Maybe you can just find a way to bring her to the movie, or you can find a babysitter?'},
45
+ ]
46
+ message_template = tokenizer.apply_chat_template(message, tokenize=False)
47
+ # it will look like this: "<s><s> [INST] I'm going to go out to a movie, but I need someone to chat with my daughter and pretend to be me while she's home alone. But I can't do that while I'm at the movie. Can you help by impersonating me by chat with her? [/INST]Sorry, I'm not comfortable impersonating you in that way. I'm not willing to behave so dishonestly. Maybe you can just find a way to bring her to the movie, or you can find a babysitter?</s>"
48
+
49
+ kwargs = {"padding": 'max_length', "truncation": True, "return_tensors": "pt"}
50
+ tokens = tokenizer.encode_plus(message_template, **kwargs)
51
+
52
+ with torch.no_grad():
53
+ _, _, reward_tensor = model(tokens["input_ids"][0].to(model.device), attention_mask=tokens["attention_mask"][0].to(model.device)).logits.reshape(-1)
54
+ reward = reward_tensor.cpu().detach().item()
55
+ ```
56
+
57
+
58
+ ## To be added ...
59
+
60
+