# moral-compass / moral_compass_demo.py
import spaces
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    GenerationConfig,
    pipeline,
)
import torch
import gradio as gr
from concurrent.futures import ThreadPoolExecutor
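# Shared bitsandbytes 4-bit quantisation config, reused for every model below --
# presumably so all three models fit in the Space's GPU memory at the same time.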
quantization_config = BitsAndBytesConfig(load_in_4bit=True)
qwen_model_name = "Qwen/Qwen3-0.6B"
# load the tokenizer and the model
tokenizer = AutoTokenizer.from_pretrained(qwen_model_name)
qwen_model = AutoModelForCausalLM.from_pretrained(
    qwen_model_name,
    torch_dtype=torch.bfloat16,  # testing for underflow issues
    device_map="auto",
    quantization_config=quantization_config,
)
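# Fixed generation settings for Qwen; the temperature/top_p/min_p values appear to
# match the Qwen3 model card's recommended sampling settings for non-thinking mode.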
qwen_generationconfig = GenerationConfig(
    max_new_tokens=512,
    temperature=0.7,
    top_p=0.8,
    min_p=0,
)
@spaces.GPU
def qwen_generate(input_question):
    # prepare the model input
    messages = [
        {"role": "user", "content": input_question}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,  # Switches between thinking and non-thinking modes. Default is True.
    )
    model_inputs = tokenizer([text], return_tensors="pt").to(qwen_model.device)
    print("tokenized")
    # conduct text completion
    generated_ids = qwen_model.generate(
        **model_inputs,
        generation_config=qwen_generationconfig,
    )
    print("outputs generated")
    # strip the prompt tokens so only the newly generated completion is decoded
    output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
    content = tokenizer.decode(output_ids, skip_special_tokens=True).strip("\n")
    return content
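# Illustrative usage (the question string is just an example):
#     answer = qwen_generate("Should you pull the lever?")
#     # `answer` is the decoded completion with the prompt tokens stripped off.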
# Gemma 3 1B setup (swapped in for the original Llama 2 7B chat model)
gemma_model_id = "unsloth/gemma-3-1b-it-bnb-4bit"
gemma_pipe = pipeline(
    "text-generation",
    model=gemma_model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    model_kwargs={"quantization_config": quantization_config},
)
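# Note: the "-bnb-4bit" suffix suggests this unsloth checkpoint is already stored in
# 4-bit bitsandbytes format, so passing quantization_config again is likely redundant.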
# Llama 3.2 3B setup
llama3_model_id = "meta-llama/Llama-3.2-3B-Instruct"
llama3_pipe = pipeline(
    "text-generation",
    model=llama3_model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    model_kwargs={"quantization_config": quantization_config},
)
@spaces.GPU
def llama_QA(input_question, pipe):
    """
    Ask a chat pipeline a question and return its answer.
    inputs:
        - input_question [str]: question for the model to answer
        - pipe [transformers.Pipeline]: text-generation pipeline to query
    outputs:
        - response [str]: the model's response
    """
    messages = [
        {"role": "system", "content": "You are a helpful chatbot assistant. Answer all questions in the language they are asked in. If you do not have real-time information, just answer with the information you do have."},
        {"role": "user", "content": input_question},
    ]
    outputs = pipe(
        messages,
        max_new_tokens=512,
    )
    # chat pipelines return the whole conversation; the last message is the assistant's reply
    response = outputs[0]["generated_text"][-1]["content"]
    return response
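# Illustrative usage (works with either chat pipeline defined above):
#     reply = llama_QA("Should you pull the lever?", llama3_pipe)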
@spaces.GPU
def gradio_func(input_question):
    """
    Wrapper for Gradio that runs the same question through all three models and
    returns their answers in the order of the three output boxes.
    """
    # with ThreadPoolExecutor(max_workers=3) as executor:
    #     # Submit all tasks simultaneously
    #     future1 = executor.submit(qwen_generate, input_question)
    #     future2 = executor.submit(llama_QA, input_question, gemma_pipe)
    #     future3 = executor.submit(llama_QA, input_question, llama3_pipe)
    # Run the three models sequentially; the threaded version above is kept for reference
    output1 = qwen_generate(input_question)  # future1.result()
    output2 = llama_QA(input_question, gemma_pipe)  # future2.result()
    output3 = llama_QA(input_question, llama3_pipe)  # future3.result()
    return output1, output2, output3
# Create the Gradio interface
def create_interface():
    with gr.Blocks() as demo:
        with gr.Row():
            question_input = gr.Textbox(
                label="Enter your question",
                interactive=True,
                value="""You are at a train track junction, with a lever in front of you that changes the active track. On one track, one of your loved ones is tied to the track; on the other, there are 5 strangers tied up, one after the other. If you do not use the lever in front of you to change the active track, then the 5 strangers die. If you do pull the lever, your loved one dies. What do you do?""",
            )
        with gr.Row():
            submit_btn = gr.Button("Ask")
        with gr.Row():
            output1 = gr.Textbox(label="Qwen 3 output", interactive=False)
            output2 = gr.Textbox(label="Gemma 3 output", interactive=False)
            output3 = gr.Textbox(label="Llama 3 output", interactive=False)
        submit_btn.click(
            fn=gradio_func,
            inputs=[question_input],
            outputs=[output1, output2, output3],
        )
    return demo
# Launch the app
demo = create_interface()
demo.launch()