Schrieffer2sy commited on
Commit
739e21a
·
1 Parent(s): 06d7c11
Files changed (1) hide show
  1. app.py +10 -12
app.py CHANGED
@@ -4,30 +4,30 @@ from transformers import AutoTokenizer
4
  from sarm_llama import LlamaSARM
5
 
6
  # --- 1. Load Model and Tokenizer ---
7
- # This step automatically downloads your model files from the Hugging Face Hub.
8
- # Ensure your model repository is public.
9
 
10
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
  MODEL_ID = "schrieffer/SARM-4B"
12
 
13
- print(f"Loading model: {MODEL_ID} on {DEVICE}...")
14
 
15
  # trust_remote_code=True is required because SARM has a custom architecture.
 
16
  model = LlamaSARM.from_pretrained(
17
  MODEL_ID,
18
  sae_hidden_state_source_layer=16,
19
  sae_latent_size=65536,
20
  sae_k=192,
21
- device_map=DEVICE,
22
  trust_remote_code=True,
23
  torch_dtype=torch.bfloat16
24
  )
25
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
26
 
27
- print("Model loaded successfully!")
 
 
28
 
29
  # --- 2. Define the Inference Function ---
30
- # This function will be called by Gradio.
31
 
32
  def get_reward_score(prompt: str, response: str) -> float:
33
  """
@@ -39,7 +39,8 @@ def get_reward_score(prompt: str, response: str) -> float:
39
  try:
40
  # Use the same chat template as used during model training.
41
  messages = [{"role": "user", "content": prompt}, {"role": "assistant", "content": response}]
42
- input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt").to(DEVICE)
 
43
 
44
  with torch.no_grad():
45
  score = model(input_ids).logits.item()
@@ -47,12 +48,10 @@ def get_reward_score(prompt: str, response: str) -> float:
47
  return round(score, 4)
48
  except Exception as e:
49
  print(f"Error: {e}")
50
- # It might be better to return an error message on the UI, but here we simply return 0.
51
  return 0.0
52
 
53
  # --- 3. Create and Launch the Gradio Interface ---
54
 
55
- # Use gr.Blocks() for a more flexible layout.
56
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
57
  gr.Markdown(
58
  """
@@ -94,7 +93,6 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
94
  calculate_btn = gr.Button("Calculate Reward Score", variant="primary")
95
  score_output = gr.Number(label="Reward Score", info="A higher score is better.")
96
 
97
- # Define the button's click behavior.
98
  calculate_btn.click(
99
  fn=get_reward_score,
100
  inputs=[prompt_input, response_input],
@@ -111,7 +109,7 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
111
  inputs=[prompt_input, response_input],
112
  outputs=score_output,
113
  fn=get_reward_score,
114
- cache_examples=True # Cache the results of the examples to speed up loading.
115
  )
116
 
117
  # Launch the application.
 
4
  from sarm_llama import LlamaSARM
5
 
6
  # --- 1. Load Model and Tokenizer ---
 
 
7
 
8
+ # No longer need to manually check for CUDA. `device_map="auto"` will handle it.
9
  MODEL_ID = "schrieffer/SARM-4B"
10
 
11
+ print(f"Loading model: {MODEL_ID} with device_map='auto'...")
12
 
13
  # trust_remote_code=True is required because SARM has a custom architecture.
14
+ # Using device_map="auto" is the key to correctly loading the model onto the GPU.
15
  model = LlamaSARM.from_pretrained(
16
  MODEL_ID,
17
  sae_hidden_state_source_layer=16,
18
  sae_latent_size=65536,
19
  sae_k=192,
20
+ device_map="auto", # <<< KEY CHANGE HERE
21
  trust_remote_code=True,
22
  torch_dtype=torch.bfloat16
23
  )
24
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
25
 
26
+ # We can get the device from the model itself after loading
27
+ DEVICE = model.device
28
+ print(f"Model loaded successfully on device: {DEVICE}")
29
 
30
  # --- 2. Define the Inference Function ---
 
31
 
32
  def get_reward_score(prompt: str, response: str) -> float:
33
  """
 
39
  try:
40
  # Use the same chat template as used during model training.
41
  messages = [{"role": "user", "content": prompt}, {"role": "assistant", "content": response}]
42
+ # With device_map="auto", accelerate's dispatch hooks move inputs onto the model's device,
+ # so an explicit .to(DEVICE) is no longer required here.
43
+ input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt") # <<< REMOVED .to(DEVICE)
44
 
45
  with torch.no_grad():
46
  score = model(input_ids).logits.item()
 
48
  return round(score, 4)
49
  except Exception as e:
50
  print(f"Error: {e}")
 
51
  return 0.0
52
 
53
  # --- 3. Create and Launch the Gradio Interface ---
54
 
 
55
  with gr.Blocks(theme=gr.themes.Soft()) as demo:
56
  gr.Markdown(
57
  """
 
93
  calculate_btn = gr.Button("Calculate Reward Score", variant="primary")
94
  score_output = gr.Number(label="Reward Score", info="A higher score is better.")
95
 
 
96
  calculate_btn.click(
97
  fn=get_reward_score,
98
  inputs=[prompt_input, response_input],
 
109
  inputs=[prompt_input, response_input],
110
  outputs=score_output,
111
  fn=get_reward_score,
112
+ cache_examples=True
113
  )
114
 
115
  # Launch the application.