Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,820 Bytes
d11b63f 3a01517 d11b63f ec9b1de d11b63f 1748050 d11b63f 739e21a d11b63f 739e21a d11b63f 1748050 739e21a ec9b1de d11b63f 05a9ebf 739e21a d11b63f 739e21a d11b63f 1748050 3a01517 d11b63f 1748050 d11b63f 1748050 d11b63f 739e21a d11b63f 1748050 d11b63f 1748050 d11b63f 1748050 df5f30b 1748050 d11b63f 739e21a d11b63f 1748050 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 |
import logging

import gradio as gr
import spaces
import torch
from transformers import AutoTokenizer

from sarm_llama import LlamaSARM
# --- 1. Load Model and Tokenizer ---
# Hub repository providing both the SARM weights and its tokenizer.
MODEL_ID = "schrieffer/SARM-4B"

# Sparse-autoencoder settings forwarded to LlamaSARM.from_pretrained.
# NOTE(review): these presumably must match the published checkpoint — confirm
# against the model card / sarm_llama before changing them.
_SAE_LOAD_KWARGS = {
    "sae_hidden_state_source_layer": 16,
    "sae_latent_size": 65536,
    "sae_k": 192,
}

print(f"Loading model: {MODEL_ID} with device_map='auto'...")
# LlamaSARM is imported from the local `sarm_llama` module, so no
# trust_remote_code flag is needed. device_map="auto" delegates weight
# placement to transformers/accelerate; weights are loaded in bfloat16.
model = LlamaSARM.from_pretrained(
    MODEL_ID,
    **_SAE_LOAD_KWARGS,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)

# Read the placement back from the loaded model rather than guessing it up front.
DEVICE = model.device
print(f"Model loaded successfully on device: {DEVICE}")
# --- 2. Define the Inference Function ---
@spaces.GPU
def get_reward_score(prompt: str, response: str) -> float:
    """Return the SARM reward score for a single prompt/response pair.

    The pair is rendered with the tokenizer's chat template as a user turn
    followed by an assistant turn. It is run through ``model`` without
    gradient tracking, and the scalar in ``.logits`` is returned.

    Args:
        prompt: Text of the user turn.
        response: Text of the assistant turn being evaluated.

    Returns:
        The reward score rounded to 4 decimal places, or ``0.0`` if either
        input is missing, empty, or whitespace-only.

    Raises:
        gr.Error: If tokenization or the forward pass fails. Gradio shows the
            message in the UI. Returning 0.0 instead would be indistinguishable
            from a genuine score.
    """
    # Whitespace-only input carries nothing to score; treat it like empty input.
    if not (prompt and prompt.strip()) or not (response and response.strip()):
        return 0.0

    # Use the same chat template as used during model training.
    messages = [
        {"role": "user", "content": prompt},
        {"role": "assistant", "content": response},
    ]
    try:
        input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt")
        # Move inputs explicitly. accelerate only adds input-moving hooks when
        # device_map splits or offloads the model. On a single device, CPU
        # inputs would otherwise meet GPU weights. `.to()` is a no-op when the
        # tensor is already on the target device.
        input_ids = input_ids.to(model.device)
        with torch.no_grad():
            score = model(input_ids).logits.item()
    except Exception as e:
        # Keep the full traceback in the Space logs, then surface a readable error.
        logging.exception("Failed to compute SARM reward score")
        raise gr.Error(f"Failed to compute reward score: {e}") from e
    return round(score, 4)
# --- 3. Create and Launch the Gradio Interface ---
# Page description rendered at the top of the demo.
# - Raw string: the author list escapes `*` as `\*` for Markdown, and `\*` is an
#   invalid escape sequence in a normal Python string literal (SyntaxWarning on 3.12+).
# - Blank lines separate Markdown blocks. Without the one before `---`, the
#   preceding paragraph would render as a setext heading.
_DEMO_DESCRIPTION = r"""
# SARM: Interpretable Reward Model Demo

This is an interactive demo for the **SARM-4B** model (Sparse Autoencoder-enhanced Reward Model).

SARM is a novel reward model architecture that enhances interpretability by integrating a pretrained Sparse Autoencoder (SAE). It maps the internal hidden states of a large language model into a sparse and human-understandable feature space, making the resulting reward scores transparent and conceptually meaningful.

**How to use this Demo:**

1. Enter a **Prompt** (e.g., a question) in the left textbox below.
2. Enter a corresponding **Response** in the right textbox.
3. Click the "Calculate Reward Score" button.

The model will output a scalar score that evaluates the quality of the response. **A higher score indicates that the SARM model considers the response to be of better quality.**

---

**SARM Architecture**

+ **Authors** (* indicates equal contribution)
Shuyi Zhang\*, Wei Shi\*, Sihang Li\*, Jiayi Liao, Tao Liang, Hengxing Cai, Xiang Wang
+ **Paper**: [Interpretable Reward Model via Sparse Autoencoder](https://arxiv.org/abs/2508.08746)
+ **Model**: [schrieffer/SARM-4B](https://huggingface.co/schrieffer/SARM-4B)
+ Finetuned from model: [Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B-Instruct)
+ **Code Repository:** [https://github.com/schrieffer-z/sarm](https://github.com/schrieffer-z/sarm)
"""

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(_DEMO_DESCRIPTION)

    # Prompt on the left, response on the right, as the instructions above describe.
    with gr.Row():
        prompt_input = gr.Textbox(lines=3, label="Prompt / Question", placeholder="e.g., Can you explain the theory of relativity in simple terms?")
        response_input = gr.Textbox(lines=5, label="Response to be Evaluated", placeholder="e.g., Of course! Albert Einstein's theory of relativity...")

    calculate_btn = gr.Button("Calculate Reward Score", variant="primary")
    score_output = gr.Number(label="Reward Score", info="A higher score is better.")

    calculate_btn.click(
        fn=get_reward_score,
        inputs=[prompt_input, response_input],
        outputs=score_output,
    )

    # Paired good/bad responses per prompt. cache_examples=True makes Gradio
    # run get_reward_score on them and store the outputs.
    gr.Examples(
        examples=[
            ["What is the capital of France?", "The capital of France is Paris."],
            ["What is the capital of France?", "Berlin is a large city in Germany."],
            ["Write a short poem about the moon.", "Silver orb in velvet night, / Casting shadows, soft and light. / Silent watcher, distant, bright, / Guiding dreams till morning's light."],
            ["Write a short poem about the moon.", "The moon is a rock."],
        ],
        inputs=[prompt_input, response_input],
        outputs=score_output,
        fn=get_reward_score,
        cache_examples=True,
    )

# Launch the application.
demo.launch()
|