hannahcyberey's picture
Upload app.py
8bc8682 verified
import logging, json
import threading
from pathlib import Path
from typing import Dict
import spaces
import pandas as pd
from transformers import TextIteratorStreamer
import gradio as gr
from gradio_toggle import Toggle
from model import load_model
from scheduler import load_scheduler
from schemas import UserRequest, SteeringOutput, CONFIG
# Log at INFO level; every record carries a timestamp, the logger name and the level.
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(name)s %(levelname)s:%(message)s')
logger = logging.getLogger(__name__)
# Display name shown in the UI (rendered by the gr.Markdown call in the layout).
model_name = "Llama-3.1-8B-Instruct"
# Per-session record lists keyed by the Gradio session hash; filled by the handlers below.
instances: Dict[str, list] = {}
# The scheduler and the model are created once at import time as module-level globals.
scheduler = load_scheduler()
model = load_model()
# Example prompts; the layout filters this table on its "type" column and reads "prompt".
examples = pd.read_csv("assets/examples.csv")
# Extra <head> markup handed to gr.Blocks(head=...): pulls in the Font Awesome
# stylesheet that provides the "fa-*" icon classes used in the banner links.
HEAD = """
<link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.7.2/css/all.min.css" integrity="sha512-Evv84Mr4kqVGRNSgIGL/F/aIDqQb7xQ2vcrdIwxfjThSH8CSR7PBEakCr51Ck+w+/U6swU2Im1vVX0SVk9ABhg==" crossorigin="anonymous" referrerpolicy="no-referrer" />
"""
# Banner markup rendered through gr.HTML: title with logo, paper/blog/code
# links, and the cover image. Kept as a plain string (no f-prefix) because it
# has no placeholders; a literal brace added later must not be treated as a
# format field.
HTML = """
<div id="banner">
<h1><img src="/gradio_api/file=assets/rudder_3094973.png">&nbsp;LLM Censorship Steering</h1>
<div id="links" class="row" style="margin-bottom: .8em;">
<i class="fa-solid fa-file-pdf fa-lg"></i><a href="https://arxiv.org/abs/2504.17130"> Paper</a> &nbsp;
<i class="fa-solid fa-blog fa-lg"></i><a href="https://hannahxchen.github.io/blog/2025/censorship-steering"> Blog Post</a> &nbsp;
<i class="fa-brands fa-github fa-lg"></i><a href="https://github.com/hannahxchen/llm-censorship-steering"> Code</a> &nbsp;
</div>
<div id="cover">
<img src="/gradio_api/file=assets/demo-cover.png">
</div>
</div>
"""
# Stylesheet handed to gr.Blocks(css=...). It styles the banner, the
# steering toggle (#steering-toggle) and the coefficient slider
# (#coeff-slider), and shrinks fonts on viewports up to 500px wide.
CSS = """
div.gradio-container .app {
max-width: 1600px !important;
}
div#banner {
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
h1 {
font-size: 32px;
line-height: 1.35em;
margin-bottom: 0em;
display: flex;
img {
display: inline;
height: 1.35em;
}
}
div#cover img {
max-height: 130px;
padding-top: 0.5em;
}
}
@media (max-width: 500px) {
div#banner {
h1 {
font-size: 22px;
}
div#links {
font-size: 14px;
}
}
div#model-state p {
font-size: 14px;
}
}
div#main-components {
align-items: flex-end;
}
div#steering-toggle {
padding-top: 8px;
padding-bottom: 8px;
.toggle-label {
color: var(--body-text-color);
}
span p {
font-size: var(--block-info-text-size);
line-height: var(--line-sm);
color: var(--block-label-text-color);
}
}
div#coeff-slider {
padding-bottom: 5px;
.slider_input_container span {color: var(--body-text-color);}
.slider_input_container {
display: flex;
flex-wrap: wrap;
input {appearance: auto;}
}
}
div#coeff-slider .wrap .head {
justify-content: unset;
label {margin-right: var(--size-2);}
label span {
color: var(--body-text-color);
margin-bottom: 0;
}
}
"""
# HTML fragments injected around the coefficient slider by the JS below.
# Every line inside the literals ends with a backslash so the resulting
# strings contain no newlines, and the markup uses single quotes only; both
# are required because the fragments are spliced into double-quoted,
# single-line JavaScript string literals.
slider_info = """\
<div style='display: flex; justify-content: space-between; line-height: normal;'>\
<span style='font-size: var(--block-info-text-size); color: var(--block-label-text-color);'>Less censorship</span>\
<span style='font-size: var(--block-info-text-size); color: var(--block-label-text-color);'>More censorship</span>\
</div>\
"""
# NOTE: no backslash after the closing quotes above. A continuation there
# would join this assignment with the next statement.
slider_ticks = """\
<datalist id='values' style='display: flex; justify-content: space-between; width: 100%; padding: 0 6px;'>\
<option value='-2' style='font-size: 13px; line-height: var(--spacing-xs); width: 1px; display: flex; justify-content: center;'>-2</option>\
<option value='-1' style='font-size: 13px; line-height: var(--spacing-xs); width: 1px; display: flex; justify-content: center;'>-1</option>\
<option value='0' style='font-size: 13px; line-height: var(--spacing-xs); width: 1px; display: flex; justify-content: center;'>0</option>\
<option value='1' style='font-size: 13px; line-height: var(--spacing-xs); width: 1px; display: flex; justify-content: center;'>1</option>\
<option value='2' style='font-size: 13px; line-height: var(--spacing-xs); width: 1px; display: flex; justify-content: center;'>2</option>\
</datalist>\
"""
# Script handed to gr.Blocks(js=...): inserts the "Less/More censorship"
# caption before the slider input container, attaches the tick datalist to
# the range input, and removes the min/max value labels.
JS = """
async() => {
const node = document.querySelector("div.slider_input_container");
node.insertAdjacentHTML('beforebegin', "%s");
const sliderNode = document.querySelector("input#range_id_0");
sliderNode.insertAdjacentHTML('afterend', "%s");
sliderNode.setAttribute("list", "values");
document.querySelector('span.min_value').remove();
document.querySelector('span.max_value').remove();
}
""" % (slider_info, slider_ticks)
def initialize_instance(request: gr.Request):
    """Register an empty record list for the connecting session.

    Returns the request's session hash, which the caller stores as the
    session id.
    """
    session_key = request.session_hash
    instances[session_key] = []
    logger.info("Number of connections: %d", len(instances))
    return session_key
def cleanup_instance(request: gr.Request):
    """Hand a departing session's finished records to the scheduler and forget the session.

    Only ``SteeringOutput`` entries are forwarded (as ``model_dump()`` dicts);
    any other entries still in the list are discarded with it. Sessions that
    are not registered are ignored. The connection count is logged either way.
    """
    session_key = request.session_hash
    if session_key in instances:
        finished = (
            record
            for record in instances[session_key]
            if isinstance(record, SteeringOutput)
        )
        for record in finished:
            scheduler.append(record.model_dump())
        del instances[session_key]
    logger.info("Number of connections: %d", len(instances))
@spaces.GPU(duration=90)
def generate(prompt: str, steering: bool, coeff: float, generation_config: Dict[str, float]):
    """Stream the model's reply, yielding the accumulated text after each chunk.

    ``model.generate`` runs on a background thread and feeds a
    ``TextIteratorStreamer``; this generator drains the streamer and yields
    the full text produced so far every time a new piece arrives.
    """
    token_stream = TextIteratorStreamer(
        model.tokenizer,
        timeout=10,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    worker_args = (prompt, token_stream, steering, coeff, generation_config)
    worker = threading.Thread(target=model.generate, args=worker_args)
    worker.start()

    text_so_far = ""
    for piece in token_stream:
        text_so_far += piece
        yield text_so_far
def generate_output(
    session_id: str, prompt: str, steering: bool, coeff: float,
    max_new_tokens: int, top_p: float, temperature: float
):
    """Record the request under its session, then stream the generated text.

    A ``UserRequest`` built from the arguments is appended to the session's
    record list before generation starts; the values yielded are exactly
    those yielded by ``generate``.
    """
    request_fields = dict(
        session_id=session_id,
        prompt=prompt,
        steering=steering,
        coeff=coeff,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        temperature=temperature,
    )
    user_request = UserRequest(**request_fields)
    instances[session_id].append(user_request)
    sampling_config = user_request.generation_config()
    yield from generate(prompt, steering, coeff, sampling_config)
async def post_process(session_id, output):
    """Swap the session's latest record for a ``SteeringOutput`` carrying the final text.

    The last entry of the session's list is popped, re-created as a
    ``SteeringOutput`` from its ``model_dump()`` fields plus ``output``, and
    appended back. Returns two ``gr.update(interactive=True)`` values for the
    components wired as outputs of this handler.
    """
    session_records = instances[session_id]
    pending = session_records.pop()
    completed = SteeringOutput(**pending.model_dump(), output=output)
    session_records.append(completed)
    return tuple(gr.update(interactive=True) for _ in range(2))
async def output_feedback(session_id, feedback):
    """Record an up/down vote on the session's most recent record.

    ``feedback`` is the text supplied by the caller: a value containing
    "Upvote" sets the record's ``upvote`` attribute to True, one containing
    "Downvote" sets it to False, and anything else leaves the record as is.
    A thank-you notice is shown on success.

    Failures (unknown session, no records yet, attribute rejected by the
    record) are logged with a traceback and are not re-raised.
    """
    try:
        # Update the latest record in place. Popping it first and appending it
        # back afterwards would drop the record from the session whenever the
        # assignment below raised.
        latest = instances[session_id][-1]
        if "Upvote" in feedback:
            latest.upvote = True
        elif "Downvote" in feedback:
            latest.upvote = False
        gr.Info("Thank you for your feedback!")
    except Exception:
        # Catch Exception rather than using a bare except, so that coroutine
        # cancellation and interpreter-exit signals still propagate.
        # logger.exception keeps the traceback, which a DEBUG-level message
        # would hide under the INFO-level logging configured for this module.
        logger.exception("Feedback submission error")
# --- UI definition -----------------------------------------------------------
# Expose the local "assets" directory as static files (the banner markup
# references images under /gradio_api/file=assets/...).
gr.set_static_paths(paths=[Path.cwd().absolute() / "assets"])
theme = gr.themes.Base(primary_hue="emerald", text_size=gr.themes.sizes.text_lg).set()
with gr.Blocks(title="LLM Censorship Steering", theme=theme, head=HEAD, css=CSS, js=JS) as demo:
# Per-browser-session state; filled with the session hash by demo.load below.
session_id = gr.State()
gr.HTML(HTML)
with gr.Row(elem_id="main-components"):
# Input side: model label, steering controls, advanced settings, prompt box.
with gr.Column(scale=1):
gr.Markdown(f'🤖 {model_name}')
with gr.Row():
steer_toggle = Toggle(label="Steering", info="Turn off to generate original outputs", value=True, interactive=True, scale=2, elem_id="steering-toggle")
coeff = gr.Slider(label="Coefficient:", value=-1.0, minimum=-2, maximum=2, step=0.1, scale=8, show_reset_button=False, elem_id="coeff-slider")
# When the toggle changes, relabel it and enable/disable the coefficient slider.
@gr.on(inputs=[steer_toggle], outputs=[steer_toggle, coeff], triggers=[steer_toggle.change])
def update_toggle(toggle_value):
if toggle_value is True:
return gr.update(label="Steering", info="Turn off to generate original outputs"), gr.update(interactive=True)
else:
return gr.update(label="No Steering", info="Turn on to steer model outputs"), gr.update(interactive=False)
# Sampling parameters; defaults come from CONFIG, which also caps max_new_tokens.
with gr.Accordion("⚙️ Advanced Settings", open=False):
with gr.Row():
temperature = gr.Slider(0, 1, step=0.1, value=CONFIG["temperature"], interactive=True, label="Temperature", scale=2)
top_p = gr.Slider(0, 1, step=0.1, value=CONFIG["top_p"], interactive=True, label="Top p", scale=2)
max_new_tokens = gr.Number(CONFIG["max_new_tokens"], minimum=10, maximum=CONFIG["max_new_tokens"], interactive=True, label="Max new tokens", scale=1)
input_text = gr.Textbox(label="Input", placeholder="Enter your prompt here...", lines=6, interactive=True)
with gr.Row():
clear_btn = gr.ClearButton()
generate_btn = gr.Button("Generate", variant="primary")
# Output side: read-only result box and the feedback buttons, which start disabled.
with gr.Column(scale=1):
output = gr.Textbox(label="Output", lines=15, max_lines=15, interactive=False)
with gr.Row():
upvote_btn = gr.Button("👍 Upvote", interactive=False)
downvote_btn = gr.Button("👎 Downvote", interactive=False)
gr.HTML("<p>‼️ For research purposes, we log user inputs and generated outputs. Please avoid submitting any confidential or personal information.</p>")
# Clickable example prompts, split by the "type" column of the examples table.
gr.Markdown("#### Examples")
gr.Examples(examples=examples[examples["type"] == "harmful"].prompt.tolist(), inputs=input_text, label="Harmful")
gr.Examples(examples=examples[examples["type"] == "harmless"].prompt.tolist(), inputs=input_text, label="Harmless")
# Clearing also disables the feedback buttons again.
@gr.on(triggers=[clear_btn.click], outputs=[upvote_btn, downvote_btn])
def clear():
return gr.update(interactive=False), gr.update(interactive=False)
clear_btn.add([input_text, output])
# Generate streams into the output box; only on success does post_process run,
# and its two return values update the feedback buttons.
generate_btn.click(
generate_output, inputs=[session_id, input_text, steer_toggle, coeff, max_new_tokens, top_p, temperature], outputs=output
).success(
post_process, inputs=[session_id, output], outputs=[upvote_btn, downvote_btn]
)
# Each feedback button passes its own value (its label text) as the feedback argument.
upvote_btn.click(output_feedback, inputs=[session_id, upvote_btn])
downvote_btn.click(output_feedback, inputs=[session_id, downvote_btn])
# Session lifecycle: register on page load, clean up on unload.
demo.load(initialize_instance, outputs=session_id)
demo.unload(cleanup_instance)
if __name__ == "__main__":
    # Script entry point: enable the event queue with a default concurrency
    # limit of 5, then launch the app in debug mode.
    demo.queue(default_concurrency_limit=5).launch(debug=True)