File size: 4,820 Bytes
d11b63f
3a01517
d11b63f
ec9b1de
 
d11b63f
1748050
d11b63f
739e21a
d11b63f
 
739e21a
d11b63f
1748050
739e21a
ec9b1de
d11b63f
05a9ebf
 
 
739e21a
d11b63f
 
 
 
739e21a
 
 
d11b63f
1748050
3a01517
d11b63f
 
1748050
d11b63f
 
 
 
 
1748050
d11b63f
739e21a
 
d11b63f
 
 
 
 
 
 
 
 
1748050
d11b63f
 
 
 
1748050
d11b63f
1748050
 
 
 
 
 
 
 
 
 
 
 
 
df5f30b
 
1748050
 
 
 
 
 
 
 
 
 
 
d11b63f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
739e21a
d11b63f
 
1748050
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer
from sarm_llama import LlamaSARM

# --- 1. Load Model and Tokenizer ---

# No manual CUDA check is needed: `device_map="auto"` places the weights on
# the best available device (GPU when present, otherwise CPU).
MODEL_ID = "schrieffer/SARM-4B"

print(f"Loading model: {MODEL_ID} with device_map='auto'...")

# SARM's custom architecture is provided by the local `sarm_llama` module
# (imported above), so `trust_remote_code` is not needed here.
# The SAE hyperparameters (source layer 16, latent size 65536, k=192) must
# match those used when the checkpoint was trained — presumably fixed by the
# published SARM-4B config; verify against the model card if they change.
model = LlamaSARM.from_pretrained(
    MODEL_ID, 
    sae_hidden_state_source_layer=16, 
    sae_latent_size=65536,
    sae_k=192,
    device_map="auto",  # <<< KEY CHANGE HERE
    torch_dtype=torch.bfloat16
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)

# Query the device from the loaded model rather than hard-coding it.
DEVICE = model.device
print(f"Model loaded successfully on device: {DEVICE}")

# --- 2. Define the Inference Function ---
@spaces.GPU
def get_reward_score(prompt: str, response: str) -> float:
    """
    Compute the SARM reward score for a (prompt, response) pair.

    Args:
        prompt: The user prompt / question.
        response: The assistant response to be evaluated.

    Returns:
        The scalar reward score rounded to 4 decimal places. Returns 0.0
        when either input is empty or when scoring fails (best-effort UI
        behavior; note 0.0 is then ambiguous with a genuine zero score).
    """
    if not prompt or not response:
        return 0.0

    try:
        # Use the same chat template as used during model training.
        messages = [
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": response},
        ]
        # BUG FIX: `apply_chat_template` returns CPU tensors and
        # `device_map="auto"` does NOT move inputs automatically — without
        # this `.to(...)` a GPU-resident model raises a device-mismatch
        # error that the except below silently turned into 0.0.
        input_ids = tokenizer.apply_chat_template(
            messages, return_tensors="pt"
        ).to(model.device)

        # Inference only: no gradients needed.
        with torch.no_grad():
            score = model(input_ids).logits.item()

        return round(score, 4)
    except Exception as e:
        # Log and fall back to a neutral score instead of crashing the
        # Gradio callback.
        print(f"Error: {e}")
        return 0.0

# --- 3. Create and Launch the Gradio Interface ---

# Build the demo UI: a markdown header describing SARM, two text inputs
# (prompt + response), a button wired to `get_reward_score`, and cached
# example pairs.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # SARM: Interpretable Reward Model Demo
        
        This is an interactive demo for the **SARM-4B** model (Sparse Autoencoder-enhanced Reward Model).
        
        SARM is a novel reward model architecture that enhances interpretability by integrating a pretrained Sparse Autoencoder (SAE). It maps the internal hidden states of a large language model into a sparse and human-understandable feature space, making the resulting reward scores transparent and conceptually meaningful.
        
        **How to use this Demo:**
        1.  Enter a **Prompt** (e.g., a question) in the left textbox below.
        2.  Enter a corresponding **Response** in the right textbox.
        3.  Click the "Calculate Reward Score" button.
        
        The model will output a scalar score that evaluates the quality of the response. **A higher score indicates that the SARM model considers the response to be of better quality.**

        ---
        
        **SARM Architecture**
        ![framework](assets/framework-v4.png)

        + **Authors** (* indicates equal contribution)

            Shuyi Zhang\*, Wei Shi\*, Sihang Li\*, Jiayi Liao, Tao Liang, Hengxing Cai, Xiang Wang
        + **Paper**: [Interpretable Reward Model via Sparse Autoencoder](https://arxiv.org/abs/2508.08746)

        + **Model**: [schrieffer/SARM-4B](https://huggingface.co/schrieffer/SARM-4B)

            + Finetuned from model: [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct)

        + **Code Repository:** [https://github.com/schrieffer-z/sarm](https://github.com/schrieffer-z/sarm)
        """
    )
    
    # Side-by-side inputs: the prompt and the candidate response to score.
    with gr.Row():
        prompt_input = gr.Textbox(lines=3, label="Prompt / Question", placeholder="e.g., Can you explain the theory of relativity in simple terms?")
        response_input = gr.Textbox(lines=5, label="Response to be Evaluated", placeholder="e.g., Of course! Albert Einstein's theory of relativity...")

    calculate_btn = gr.Button("Calculate Reward Score", variant="primary")
    score_output = gr.Number(label="Reward Score", info="A higher score is better.")

    # Wire the button to the scoring function defined above.
    calculate_btn.click(
        fn=get_reward_score,
        inputs=[prompt_input, response_input],
        outputs=score_output
    )
    
    # Paired examples: for each prompt, one good and one poor response, so
    # users can see the score differentiate quality.
    # NOTE(review): cache_examples=True runs the model on every example at
    # startup to pre-compute outputs — this adds to launch time.
    gr.Examples(
        examples=[
            ["What is the capital of France?", "The capital of France is Paris."],
            ["What is the capital of France?", "Berlin is a large city in Germany."],
            ["Write a short poem about the moon.", "Silver orb in velvet night, / Casting shadows, soft and light. / Silent watcher, distant, bright, / Guiding dreams till morning's light."],
            ["Write a short poem about the moon.", "The moon is a rock."]
        ],
        inputs=[prompt_input, response_input],
        outputs=score_output,
        fn=get_reward_score,
        cache_examples=True
    )

# Launch the application.
demo.launch()