# Hugging Face Space: Art-0-8B Thinking Chatbot (running on ZeroGPU)
import os
import torch
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer

# -------------------------------------------------
# Model setup (loaded once at startup)
# -------------------------------------------------
model_name = "gr0010/Art-0-8B-development"

# Load model and tokenizer globally so they are shared across requests
print("Loading model and tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Load the model directly on CUDA (the ZeroGPU pattern); the GPU itself is
# only attached while a @spaces.GPU-decorated function is running
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="cuda",  # Direct CUDA loading for ZeroGPU
    trust_remote_code=True,
)
print("Model loaded successfully!")
# -------------------------------------------------
# Core generation and parsing logic with Zero GPU
# -------------------------------------------------
@spaces.GPU(duration=120)  # Request a ZeroGPU slice for up to 120 seconds
def generate_and_parse(messages: list, temperature: float = 0.6,
                       top_p: float = 0.95, top_k: int = 20,
                       min_p: float = 0.0, max_new_tokens: int = 32768):
    """
    Takes a clean list of messages, generates a response,
    and parses it into thinking and answer parts.
    Decorated with @spaces.GPU for Zero GPU allocation.
    """
    # Apply chat template with enable_thinking=True for Qwen3
    prompt_text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=True  # Explicitly enable thinking mode
    )
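    # For reference (sketched from the Qwen3 chat-template docs, not produced
    # by this code): the rendered prompt is ChatML-style text roughly like
    #
    #   <|im_start|>system
    #   {system prompt}<|im_end|>
    #   <|im_start|>user
    #   {user message}<|im_end|>
    #   <|im_start|>assistant
    #
    # and with enable_thinking=True the model opens its reply with a
    # <think> ... </think> block, which is parsed out further below.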
    # --- CONSOLE DEBUG OUTPUT ---
    print("\n" + "=" * 50)
    print("--- RAW PROMPT SENT TO MODEL ---")
    print(prompt_text[:500] + "..." if len(prompt_text) > 500 else prompt_text)
    print("=" * 50 + "\n")

    model_inputs = tokenizer([prompt_text], return_tensors="pt").to("cuda")

    with torch.no_grad():
        generated_ids = model.generate(
            **model_inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            min_p=min_p,
            pad_token_id=tokenizer.eos_token_id,
        )

    output_token_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
thinking = "" | |
answer = "" | |
try: | |
# Find the </think> token to separate thinking from answer | |
end_think_token_id = 151668 # </think> | |
if end_think_token_id in output_token_ids: | |
end_think_idx = output_token_ids.index(end_think_token_id) + 1 | |
thinking_tokens = output_token_ids[:end_think_idx] | |
answer_tokens = output_token_ids[end_think_idx:] | |
thinking = tokenizer.decode(thinking_tokens, skip_special_tokens=True).strip() | |
# Remove <think> and </think> tags from thinking | |
thinking = thinking.replace("<think>", "").replace("</think>", "").strip() | |
answer = tokenizer.decode(answer_tokens, skip_special_tokens=True).strip() | |
else: | |
# If no </think> token found, treat everything as answer | |
answer = tokenizer.decode(output_token_ids, skip_special_tokens=True).strip() | |
# Remove any stray <think> tags | |
answer = answer.replace("<think>", "").replace("</think>", "") | |
except (ValueError, IndexError): | |
answer = tokenizer.decode(output_token_ids, skip_special_tokens=True).strip() | |
answer = answer.replace("<think>", "").replace("</think>", "") | |
return thinking, answer | |
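# A minimal, optional sanity check for generate_and_parse. The guard variable
# DEBUG_GENERATE_AND_PARSE and the example messages below are assumptions added
# for illustration only; nothing runs unless the variable is set, so the
# deployed Space is unaffected.
if os.environ.get("DEBUG_GENERATE_AND_PARSE"):
    _demo_messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "What is 17 * 23?"},
    ]
    _demo_thinking, _demo_answer = generate_and_parse(_demo_messages, max_new_tokens=512)
    print("DEMO THINKING:\n", _demo_thinking)
    print("DEMO ANSWER:\n", _demo_answer)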
# -------------------------------------------------
# Gradio UI Logic
# -------------------------------------------------
# Custom CSS for better styling
custom_css = """
.model-info {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    padding: 1rem;
    border-radius: 10px;
    margin-bottom: 1rem;
    color: white;
}
.model-info a {
    color: #fff;
    text-decoration: underline;
    font-weight: bold;
}
"""
with gr.Blocks(theme=gr.themes.Soft(), fill_height=True, css=custom_css) as demo:
    # Separate states for display and model context
    display_history_state = gr.State([])   # For Gradio chatbot display (with HTML formatting)
    model_history_state = gr.State([])     # Clean history for the model (plain text only)
    is_generating_state = gr.State(False)  # To prevent multiple submissions

    # Model info and CTA section
    gr.HTML("""
    <div class="model-info">
        <h1 style="margin: 0; font-size: 2em;">🎨 Art-0 8B Thinking Chatbot</h1>
        <p style="margin: 0.5rem 0;">
            Powered by <a href="https://huggingface.co/gr0010/Art-0-8B-development" target="_blank">Art-0-8B-development</a>
            - a fine-tuned Qwen3-8B model with advanced reasoning capabilities
        </p>
    </div>
    """)

    gr.Markdown(
        """
        Chat with Art-0-8B, featuring transparent reasoning display and custom personality instructions.
        The model shows its internal thought process when solving problems.
        """
    )
    # System prompt at the top (main feature)
    with gr.Group():
        gr.Markdown("### 📝 System Prompt (Personality & Behavior)")
        system_prompt = gr.Textbox(
            value="""Personality Instructions:
You are an AI assistant named Art developed by AGI-0.
Reasoning Instructions:
Think using bullet points and short sentences to simulate thoughts and emoticons to simulate emotions""",
            label="System Prompt",
            info="Define the model's personality and reasoning style",
            lines=5,
            interactive=True
        )

    # Main chat interface
    chatbot = gr.Chatbot(
        label="Conversation",
        elem_id="chatbot",
        bubble_full_width=False,
        height=500,
        show_copy_button=True,
        type="messages"
    )
    with gr.Row():
        user_input = gr.Textbox(
            show_label=False,
            placeholder="Type your message here...",
            scale=4,
            container=False,
            interactive=True
        )
        submit_btn = gr.Button(
            "Send",
            variant="primary",
            scale=1,
            interactive=True
        )

    with gr.Row():
        clear_btn = gr.Button("🗑️ Clear History", variant="secondary")
        retry_btn = gr.Button("🔄 Retry Last", variant="secondary")

    # Example prompts
    gr.Examples(
        examples=[
            ["Give me a short introduction to large language models."],
            ["What are the benefits of using transformers in AI?"],
            ["There are 5 birds on a branch. A hunter shoots one. How many birds are left?"],
            ["Explain quantum computing step by step."],
            ["Write a Python function to calculate the factorial of a number."],
            ["What makes Art-0 different from other AI models?"],
        ],
        inputs=user_input,
        label="💡 Example Prompts"
    )
    # Advanced settings at the bottom
    with gr.Accordion("⚙️ Advanced Generation Settings", open=False):
        with gr.Row():
            temperature = gr.Slider(
                minimum=0.1,
                maximum=2.0,
                value=0.6,
                step=0.1,
                label="Temperature",
                info="Controls randomness (higher = more creative)"
            )
            top_p = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.95,
                step=0.05,
                label="Top-p",
                info="Nucleus sampling threshold"
            )
        with gr.Row():
            top_k = gr.Slider(
                minimum=1,
                maximum=100,
                value=20,
                step=1,
                label="Top-k",
                info="Number of top tokens to consider"
            )
            min_p = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.0,
                step=0.01,
                label="Min-p",
                info="Minimum probability threshold for token sampling"
            )
        with gr.Row():
            max_new_tokens = gr.Slider(
                minimum=128,
                maximum=32768,
                value=32768,
                step=128,
                label="Max New Tokens",
                info="Maximum response length"
            )
    def handle_user_message(user_message: str, display_history: list, model_history: list,
                            system_prompt_text: str, is_generating: bool,
                            temp: float, top_p_val: float, top_k_val: int,
                            min_p_val: float, max_tokens: int):
        """
        Handles user input, updates histories, and generates the model's response.
        Implemented as a generator so intermediate UI states can be yielded.
        """
        # Prevent multiple submissions or empty messages
        # (this is a generator, so the early exit must yield, not return, its update)
        if is_generating or not user_message.strip():
            yield {
                chatbot: display_history,
                display_history_state: display_history,
                model_history_state: model_history,
                is_generating_state: is_generating,
                user_input: user_message,
                submit_btn: gr.update(interactive=not is_generating)
            }
            return

        # Set generating state
        is_generating = True

        # Update model history (clean format for the model - plain text only)
        model_history.append({"role": "user", "content": user_message.strip()})
        # Update display history (for the Gradio chatbot)
        display_history.append({"role": "user", "content": user_message.strip()})

        # Yield intermediate state to show the user message and disable input
        yield {
            chatbot: display_history,
            display_history_state: display_history,
            model_history_state: model_history,
            is_generating_state: is_generating,
            user_input: "",
            submit_btn: gr.update(interactive=False, value="🔄 Generating...")
        }

        # Prepare messages for the model (include system prompt)
        messages_for_model = []
        if system_prompt_text.strip():
            messages_for_model.append({"role": "system", "content": system_prompt_text.strip()})
        messages_for_model.extend(model_history)

        try:
            # Generate response with the chosen hyperparameters
            thinking, answer = generate_and_parse(
                messages_for_model,
                temperature=temp,
                top_p=top_p_val,
                top_k=top_k_val,
                min_p=min_p_val,
                max_new_tokens=max_tokens
            )

            # Update model history with the CLEAN answer (no HTML formatting)
            model_history.append({"role": "assistant", "content": answer})

            # Format response for display (collapsible reasoning section)
            if thinking and thinking.strip():
                formatted_response = f"""<details>
<summary><b>🤔 Show Reasoning Process</b></summary>

{thinking}
</details>

{answer}"""
            else:
                formatted_response = answer

            # Update display history with the formatted response
            display_history.append({"role": "assistant", "content": formatted_response})
        except Exception as e:
            error_msg = f"❌ Error generating response: {str(e)}"
            display_history.append({"role": "assistant", "content": error_msg})
            # Don't add the error to model history to avoid confusing the model

        # Reset generating state
        is_generating = False

        # Final yield with the complete response
        yield {
            chatbot: display_history,
            display_history_state: display_history,
            model_history_state: model_history,
            is_generating_state: is_generating,
            user_input: "",
            submit_btn: gr.update(interactive=True, value="Send")
        }
    def clear_history():
        """Clear both display and model histories."""
        return {
            chatbot: [],
            display_history_state: [],
            model_history_state: [],
            is_generating_state: False,
            user_input: "",
            submit_btn: gr.update(interactive=True, value="Send")
        }
    def retry_last(display_history: list, model_history: list, system_prompt_text: str,
                   temp: float, top_p_val: float, top_k_val: int,
                   min_p_val: float, max_tokens: int):
        """
        Retry the last user message with corrected history and generator handling.
        """
        # Safety check: ensure there is a history and the last message was from the assistant
        if not model_history or model_history[-1]["role"] != "assistant":
            # Nothing to retry: yield the current state and stop
            yield {
                chatbot: display_history,
                display_history_state: display_history,
                model_history_state: model_history,
                is_generating_state: False
            }
            return

        # Remove the last assistant message from both histories
        model_history.pop()    # Remove assistant's clean message from model history
        display_history.pop()  # Remove assistant's formatted message from display history

        # Get the last user message to resubmit, then remove it from both histories
        if model_history and model_history[-1]["role"] == "user":
            last_user_msg = model_history[-1]["content"]
            model_history.pop()    # Remove user message from model history
            display_history.pop()  # Remove user message from display history
        else:
            # No user message found: yield the current state and stop
            yield {
                chatbot: display_history,
                display_history_state: display_history,
                model_history_state: model_history,
                is_generating_state: False
            }
            return

        # Use 'yield from' to call the generator and forward its updates
        yield from handle_user_message(
            last_user_msg, display_history, model_history,
            system_prompt_text, False, temp, top_p_val, top_k_val, min_p_val, max_tokens
        )
    def on_input_change(text, is_generating):
        """Enable the Send button only when there is text and no generation in progress."""
        return gr.update(interactive=not is_generating and bool(text.strip()))

    # Event listeners
    submit_event = submit_btn.click(
        handle_user_message,
        inputs=[user_input, display_history_state, model_history_state, system_prompt,
                is_generating_state, temperature, top_p, top_k, min_p, max_new_tokens],
        outputs=[chatbot, display_history_state, model_history_state, is_generating_state,
                 user_input, submit_btn],
        show_progress=True
    )
    submit_event_enter = user_input.submit(
        handle_user_message,
        inputs=[user_input, display_history_state, model_history_state, system_prompt,
                is_generating_state, temperature, top_p, top_k, min_p, max_new_tokens],
        outputs=[chatbot, display_history_state, model_history_state, is_generating_state,
                 user_input, submit_btn],
        show_progress=True
    )

    # Clear button event
    clear_btn.click(
        clear_history,
        outputs=[chatbot, display_history_state, model_history_state, is_generating_state,
                 user_input, submit_btn]
    )

    # Retry button event (outputs must cover every component that
    # handle_user_message yields, including user_input and submit_btn)
    retry_btn.click(
        retry_last,
        inputs=[display_history_state, model_history_state, system_prompt,
                temperature, top_p, top_k, min_p, max_new_tokens],
        outputs=[chatbot, display_history_state, model_history_state, is_generating_state,
                 user_input, submit_btn],
        show_progress=True
    )

    # Update the Send button based on input text and generation state
    user_input.change(
        on_input_change,
        inputs=[user_input, is_generating_state],
        outputs=[submit_btn]
    )

if __name__ == "__main__":
    demo.launch(debug=True, share=False)