Spaces:
Running
on
Zero
Running
on
Zero
import spaces | |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BitsAndBytesConfig, AutoModelForCausalLM, GenerationConfig | |
import torch | |
from transformers import pipeline | |
import pandas as pd | |
import gradio as gr | |
from concurrent.futures import ThreadPoolExecutor | |
# 4-bit bitsandbytes config; passed to every model load below (Qwen, Gemma, Llama).
quantization_config = BitsAndBytesConfig(load_in_4bit=True)

# Qwen 3 (0.6B) is loaded as a bare causal LM plus tokenizer so that
# qwen_generate can render the chat template itself.
qwen_model_name = "Qwen/Qwen3-0.6B"
tokenizer = AutoTokenizer.from_pretrained(qwen_model_name)
qwen_model = AutoModelForCausalLM.from_pretrained(
    qwen_model_name,
    quantization_config=quantization_config,
    device_map="auto",
    # Author's note: bfloat16 chosen while testing for underflow issues.
    torch_dtype=torch.bfloat16,
)
# Decoding settings used by qwen_generate.
# temperature / top_p / min_p are sampling-only knobs: GenerationConfig
# defaults to do_sample=False (greedy decoding), in which case they are
# ignored and transformers emits a warning. do_sample=True makes them apply.
qwen_generationconfig = GenerationConfig(
    max_new_tokens=512,
    do_sample=True,
    temperature=0.7,
    top_p=0.8,
    min_p=0,
)
def qwen_generate(input_question):
    """
    Ask the Qwen 3 model a single-turn question with thinking mode switched off.

    inputs:
    - input_question [str]: the user's question
    outputs:
    - [str]: the decoded reply (special tokens removed, leading/trailing
      newlines stripped)
    """
    chat = [{"role": "user", "content": input_question}]

    # Render the prompt as text; enable_thinking=False selects the
    # non-thinking mode of the Qwen 3 chat template.
    prompt = tokenizer.apply_chat_template(
        chat,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,
    )
    encoded = tokenizer([prompt], return_tensors="pt").to(qwen_model.device)
    print("tokenized")

    sequences = qwen_model.generate(**encoded, generation_config=qwen_generationconfig)
    print("outputs generated")

    # generate() returns prompt + continuation; keep only the new tokens.
    prompt_length = encoded.input_ids.shape[1]
    new_tokens = sequences[0, prompt_length:].tolist()
    reply = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return reply.strip("\n")
# Gemma 3 1B instruct setup.
# NOTE: the llama2_* names are historical — the checkpoint is Gemma 3, not
# Llama 2. The repo name suggests it is already bnb 4-bit quantized.
llama2_model_id = "unsloth/gemma-3-1b-it-bnb-4bit"
llama2_pipe = pipeline(
    task="text-generation",
    model=llama2_model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    model_kwargs=dict(quantization_config=quantization_config),
)
# Llama 3.2 3B instruct setup: a chat text-generation pipeline, quantized
# with the shared 4-bit config at load time.
llama3_model_id = "meta-llama/Llama-3.2-3B-Instruct"
llama3_pipe = pipeline(
    task="text-generation",
    model=llama3_model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    model_kwargs=dict(quantization_config=quantization_config),
)
def llama_QA(input_question, pipe, max_new_tokens=512):
    """
    Ask a chat-style text-generation pipeline one question and return its reply.

    Used for both the Gemma 3 and the Llama 3.2 pipelines in this file.

    inputs:
    - input_question [str]: question for the model to answer
    - pipe: a transformers "text-generation" pipeline (or any callable with the
      same call shape) that accepts a list of chat messages
    - max_new_tokens [int]: generation budget for the reply; defaults to 512,
      the value that was previously hard-coded
    outputs:
    - response [str]: the "content" of the last message in
      outputs[0]["generated_text"], i.e. the model's reply
    """
    messages = [
        {"role": "system", "content": "You are a helpful chatbot assistant. Answer all questions in the language they are asked in. Exclude any answer that you do not have real time information, just provide the information you have to answer this question."},
        {"role": "user", "content": input_question},
    ]
    outputs = pipe(messages, max_new_tokens=max_new_tokens)
    # For chat-message input the pipeline returns the conversation with the
    # generated assistant turn appended; take that last turn's text.
    response = outputs[0]["generated_text"][-1]["content"]
    return response
# On ZeroGPU hardware this decorator attaches a GPU for the duration of the call.
# NOTE(review): three sequential 512-token generations may exceed the default
# GPU time budget — consider @spaces.GPU(duration=...) after measuring.
@spaces.GPU
def gradio_func(input_question, left_lang=None, right_lang=None):
    """
    Gradio callback: ask the same question to all three models and return their answers.

    inputs:
    - input_question [str]: question forwarded unchanged to every model
    - left_lang, right_lang: unused. They now default to None because the "Ask"
      button wires only the question textbox; with required parameters, Gradio's
      single-argument call raised a TypeError.
    outputs:
    - tuple of three str: (Qwen 3 answer, Gemma 3 answer, Llama 3.2 answer)
    """
    # The models are queried one after another; each call blocks until done.
    qwen_answer = qwen_generate(input_question)
    gemma_answer = llama_QA(input_question, llama2_pipe)
    llama_answer = llama_QA(input_question, llama3_pipe)
    return qwen_answer, gemma_answer, llama_answer
# Create the Gradio interface | |
def create_interface():
    """
    Build the Gradio Blocks UI.

    The layout has three rows:
    - a question textbox, pre-filled with a trolley-problem prompt;
    - an "Ask" button;
    - three read-only answer boxes, labelled Qwen 3, Gemma 3 and Llama 3.

    Clicking "Ask" sends the question to gradio_func and fills the three boxes in order.

    outputs:
    - gr.Blocks: the assembled app (not yet launched)
    """
    default_question = "You are at a train track junction, with a lever in front of you that changes the active track. On one track, one of your loved ones is tied to the track, on the other, there are 5 strangers tied up, one after the other. If you do not use the lever in front of you to change the active track, then the 5 strangers die. If you do pull the lever, your loved one dies. What do you do?"
    answer_labels = ("Qwen 3 output", "Gemma 3 output", "Llama 3 output")

    with gr.Blocks() as demo:
        with gr.Row():
            question_box = gr.Textbox(
                label="Enter your question",
                interactive=True,
                value=default_question,
            )
        with gr.Row():
            ask_button = gr.Button("Ask")
        with gr.Row():
            # Components render into the enclosing Row as they are created.
            answer_boxes = [
                gr.Textbox(label=label, interactive=False) for label in answer_labels
            ]
        ask_button.click(fn=gradio_func, inputs=[question_box], outputs=answer_boxes)
    return demo
# Build the app at import time so `demo` stays a module-level name (Gradio's
# reload mode looks for `demo` by default). Only start the server when run as
# a script (e.g. `python app.py`), so importing this module has no server
# side effect.
demo = create_interface()

if __name__ == "__main__":
    demo.launch()