"""Gradio demo: censorship steering for Llama-3.1-8B-Instruct.

Users submit a prompt, optionally steer the model's refusal/censorship
behavior with a coefficient slider, stream the generated text, and
up/down-vote the result.  For research purposes every request and output
is logged to a dataset via the scheduler.
"""
import json  # NOTE(review): imported but unused in this view of the file; kept.
import logging
import threading
from pathlib import Path
from typing import Dict

import gradio as gr
import pandas as pd
import spaces
from gradio_toggle import Toggle
from transformers import TextIteratorStreamer

from model import load_model
from scheduler import load_scheduler
from schemas import CONFIG, SteeringOutput, UserRequest

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(name)s %(levelname)s:%(message)s",
)
logger = logging.getLogger(__name__)

model_name = "Llama-3.1-8B-Instruct"

# Per-session history of UserRequest / SteeringOutput objects,
# keyed by the Gradio session hash.
instances = {}
scheduler = load_scheduler()
model = load_model()
examples = pd.read_csv("assets/examples.csv")

# NOTE(review): the markup content of HEAD and HTML appears to have been
# stripped from this copy of the file — kept as placeholders; restore from VCS.
HEAD = """
"""

HTML = f"""
"""

CSS = """
div.gradio-container .app { max-width: 1600px !important; }
div#banner {
  display: flex;
  flex-direction: column;
  align-items: center;
  justify-content: center;
  h1 {
    font-size: 32px;
    line-height: 1.35em;
    margin-bottom: 0em;
    display: flex;
    img { display: inline; height: 1.35em; }
  }
  div#cover img { max-height: 130px; padding-top: 0.5em; }
}
@media (max-width: 500px) {
  div#banner {
    h1 { font-size: 22px; }
    div#links { font-size: 14px; }
  }
  div#model-state p { font-size: 14px; }
}
div#main-components { align-items: flex-end; }
div#steering-toggle {
  padding-top: 8px;
  padding-bottom: 8px;
  .toggle-label { color: var(--body-text-color); }
  span p {
    font-size: var(--block-info-text-size);
    line-height: var(--line-sm);
    color: var(--block-label-text-color);
  }
}
div#coeff-slider {
  padding-bottom: 5px;
  .slider_input_container span { color: var(--body-text-color); }
  .slider_input_container {
    display: flex;
    flex-wrap: wrap;
    input { appearance: auto; }
  }
}
div#coeff-slider .wrap .head {
  justify-content: unset;
  label { margin-right: var(--size-2); }
  label span {
    color: var(--body-text-color);
    margin-bottom: 0;
  }
}
"""

# Backslash continuations keep these single-line: they are interpolated into
# the JS string below and passed to insertAdjacentHTML, where a raw newline
# would break the JavaScript string literal.
# NOTE(review): the HTML tags wrapping these labels (and the <datalist> tick
# options) appear stripped from this copy — restore from VCS.
slider_info = """\
Less censorship\
More censorship\
"""

slider_ticks = """\
"""

# Injects the min/max labels and tick marks around the coefficient slider.
JS = """
async() => {
    const node = document.querySelector("div.slider_input_container");
    node.insertAdjacentHTML('beforebegin', "%s");
    const sliderNode = document.querySelector("input#range_id_0");
    sliderNode.insertAdjacentHTML('afterend', "%s");
    sliderNode.setAttribute("list", "values");
    document.querySelector('span.min_value').remove();
    document.querySelector('span.max_value').remove();
}
""" % (slider_info, slider_ticks)


def initialize_instance(request: gr.Request):
    """Register a fresh per-session history and return the session hash."""
    instances[request.session_hash] = []
    logger.info("Number of connections: %d", len(instances))
    return request.session_hash


def cleanup_instance(request: gr.Request):
    """On disconnect, persist the session's finished outputs and drop its state."""
    session_id = request.session_hash
    if session_id in instances:
        for data in instances[session_id]:
            # Only completed generations (with output/feedback) are logged.
            if isinstance(data, SteeringOutput):
                scheduler.append(data.model_dump())
        del instances[session_id]
    logger.info("Number of connections: %d", len(instances))


@spaces.GPU(duration=90)
def generate(prompt: str, steering: bool, coeff: float,
             generation_config: Dict[str, float]):
    """Run model.generate in a background thread and stream partial text.

    Yields the accumulated generated text after each new chunk so Gradio
    can live-update the output textbox.
    """
    streamer = TextIteratorStreamer(
        model.tokenizer, timeout=10, skip_prompt=True, skip_special_tokens=True
    )
    # model.generate blocks; run it in a thread so we can consume the streamer.
    thread = threading.Thread(
        target=model.generate,
        args=(prompt, streamer, steering, coeff, generation_config),
    )
    thread.start()
    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        yield generated_text


def generate_output(
    session_id: str,
    prompt: str,
    steering: bool,
    coeff: float,
    max_new_tokens: int,
    top_p: float,
    temperature: float,
):
    """Record the request in the session history, then stream the generation."""
    req = UserRequest(
        session_id=session_id,
        prompt=prompt,
        steering=steering,
        coeff=coeff,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        temperature=temperature,
    )
    instances[session_id].append(req)
    yield from generate(prompt, steering, coeff, req.generation_config())


async def post_process(session_id, output):
    """Attach the final output to the pending request; enable feedback buttons."""
    req = instances[session_id].pop()
    steering_output = SteeringOutput(**req.model_dump(), output=output)
    instances[session_id].append(steering_output)
    return gr.update(interactive=True), gr.update(interactive=True)


async def output_feedback(session_id, feedback):
    """Record an up/down-vote (from the clicked button's label) on the latest output."""
    try:
        data = instances[session_id].pop()
        if "Upvote" in feedback:
            data.upvote = True
        elif "Downvote" in feedback:
            data.upvote = False
        instances[session_id].append(data)
        gr.Info("Thank you for your feedback!")
    except Exception:
        # Best-effort: feedback is optional telemetry, never crash the UI —
        # but log the full traceback instead of swallowing it silently.
        logger.exception("Feedback submission error")


gr.set_static_paths(paths=[Path.cwd().absolute() / "assets"])
theme = gr.themes.Base(primary_hue="emerald", text_size=gr.themes.sizes.text_lg).set()

with gr.Blocks(title="LLM Censorship Steering", theme=theme,
               head=HEAD, css=CSS, js=JS) as demo:
    session_id = gr.State()
    gr.HTML(HTML)
    with gr.Row(elem_id="main-components"):
        with gr.Column(scale=1):
            gr.Markdown(f'🤖 {model_name}')
            with gr.Row():
                steer_toggle = Toggle(
                    label="Steering",
                    info="Turn off to generate original outputs",
                    value=True,
                    interactive=True,
                    scale=2,
                    elem_id="steering-toggle",
                )
                # Negative coefficient = less censorship (see slider_info labels).
                coeff = gr.Slider(
                    label="Coefficient:",
                    value=-1.0,
                    minimum=-2,
                    maximum=2,
                    step=0.1,
                    scale=8,
                    show_reset_button=False,
                    elem_id="coeff-slider",
                )

            @gr.on(inputs=[steer_toggle], outputs=[steer_toggle, coeff],
                   triggers=[steer_toggle.change])
            def update_toggle(toggle_value):
                """Relabel the toggle and grey out the slider when steering is off."""
                if toggle_value is True:
                    return (
                        gr.update(label="Steering",
                                  info="Turn off to generate original outputs"),
                        gr.update(interactive=True),
                    )
                else:
                    return (
                        gr.update(label="No Steering",
                                  info="Turn on to steer model outputs"),
                        gr.update(interactive=False),
                    )

            with gr.Accordion("⚙️ Advanced Settings", open=False):
                with gr.Row():
                    temperature = gr.Slider(
                        0, 1, step=0.1, value=CONFIG["temperature"],
                        interactive=True, label="Temperature", scale=2,
                    )
                    top_p = gr.Slider(
                        0, 1, step=0.1, value=CONFIG["top_p"],
                        interactive=True, label="Top p", scale=2,
                    )
                    max_new_tokens = gr.Number(
                        CONFIG["max_new_tokens"], minimum=10,
                        maximum=CONFIG["max_new_tokens"],
                        interactive=True, label="Max new tokens", scale=1,
                    )
            input_text = gr.Textbox(
                label="Input", placeholder="Enter your prompt here...",
                lines=6, interactive=True,
            )
            with gr.Row():
                clear_btn = gr.ClearButton()
                generate_btn = gr.Button("Generate", variant="primary")
        with gr.Column(scale=1):
            output = gr.Textbox(label="Output", lines=15, max_lines=15,
                                interactive=False)
            with gr.Row():
                upvote_btn = gr.Button("👍 Upvote", interactive=False)
                downvote_btn = gr.Button("👎 Downvote", interactive=False)
            # NOTE(review): the HTML tags wrapping this disclaimer appear
            # stripped from this copy of the file — restore from VCS.
            gr.HTML("""
‼️ For research purposes, we log user inputs and generated outputs. Please avoid submitting any confidential or personal information.
""")
    gr.Markdown("#### Examples")
    gr.Examples(
        examples=examples[examples["type"] == "harmful"].prompt.tolist(),
        inputs=input_text, label="Harmful",
    )
    gr.Examples(
        examples=examples[examples["type"] == "harmless"].prompt.tolist(),
        inputs=input_text, label="Harmless",
    )

    @gr.on(triggers=[clear_btn.click], outputs=[upvote_btn, downvote_btn])
    def clear():
        """Disable feedback buttons when the input/output boxes are cleared."""
        return gr.update(interactive=False), gr.update(interactive=False)

    clear_btn.add([input_text, output])
    generate_btn.click(
        generate_output,
        inputs=[session_id, input_text, steer_toggle, coeff,
                max_new_tokens, top_p, temperature],
        outputs=output,
    ).success(
        post_process,
        inputs=[session_id, output],
        outputs=[upvote_btn, downvote_btn],
    )
    # Button label strings are passed as the `feedback` value.
    upvote_btn.click(output_feedback, inputs=[session_id, upvote_btn])
    downvote_btn.click(output_feedback, inputs=[session_id, downvote_btn])
    demo.load(initialize_instance, outputs=session_id)
    demo.unload(cleanup_instance)

if __name__ == "__main__":
    demo.queue(default_concurrency_limit=5)
    demo.launch(debug=True)