# Hugging Face Space (running on ZeroGPU): NVIDIA Nemotron Nano 9B v2 chatbot
import gradio as gr
import torch
import spaces
import subprocess
import sys

# Install the specific transformers version recommended for this model
subprocess.check_call([sys.executable, "-m", "pip", "install", "transformers==4.48.3"])

from transformers import AutoTokenizer, AutoModelForCausalLM

# Load the tokenizer now; the model itself is loaded lazily in load_model()
model_name = "nvidia/NVIDIA-Nemotron-Nano-9B-v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = None
def load_model():
    global model
    if model is None:
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
            device_map="auto"
        )
    return model
# ZeroGPU Spaces allocate a GPU only to functions decorated with @spaces.GPU; the
# `spaces` import above is otherwise unused, so the decorator is assumed to belong here.
@spaces.GPU
def generate_response(message, history, enable_reasoning, temperature, top_p, max_tokens):
    """Generate response from the model"""
    # Prepare messages with reasoning control
    messages = []

    # Add system message based on reasoning setting
    if enable_reasoning:
        messages.append({"role": "system", "content": "/think"})
    else:
        messages.append({"role": "system", "content": "/no_think"})

    # Add conversation history
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages.append({"role": "assistant", "content": assistant_msg})

    # Add current message
    messages.append({"role": "user", "content": message})
    # Load model if needed
    model = load_model()

    # Tokenize the conversation
    tokenized_chat = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt"
    ).to(model.device)
    # Set generation parameters based on reasoning mode
    if enable_reasoning:
        # Recommended sampling settings for reasoning
        generation_kwargs = {
            "temperature": temperature if temperature > 0 else 0.6,
            "top_p": top_p if top_p < 1 else 0.95,
            "do_sample": True,
            "max_new_tokens": max_tokens,
            "eos_token_id": tokenizer.eos_token_id
        }
    else:
        # Greedy search for non-reasoning
        generation_kwargs = {
            "do_sample": False,
            "max_new_tokens": max_tokens,
            "eos_token_id": tokenizer.eos_token_id
        }
    # Generate response
    with torch.no_grad():
        outputs = model.generate(tokenized_chat, **generation_kwargs)

    # Decode and extract the assistant's response
    generated_tokens = outputs[0][tokenized_chat.shape[-1]:]  # Get only new tokens
    response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    return response
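
# Optional helper (not wired into the UI below): reasoning models often emit their
# chain of thought inside <think>...</think> tags. Whether and how this particular
# model tags its trace is an assumption here; check the model card before relying on it.
def strip_reasoning_trace(text):
    """Remove a <think>...</think> block from a response, if present, before display."""
    import re
    return re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()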

# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # NVIDIA Nemotron Nano 9B v2 Chatbot

        This chatbot uses the NVIDIA Nemotron Nano 9B v2 model with optional reasoning capabilities.

        - **Enable Reasoning**: Activates the model's chain-of-thought reasoning (/think mode)
        - **Disable Reasoning**: Uses direct response generation (/no_think mode)

        **Note:** Using transformers version 4.48.3 as recommended by the model documentation.
        """
    )

    chatbot = gr.Chatbot(height=500)
    msg = gr.Textbox(
        label="Message",
        placeholder="Type your message here...",
        lines=2
    )

    with gr.Row():
        submit = gr.Button("Send", variant="primary")
        clear = gr.Button("Clear")
    with gr.Accordion("Advanced Settings", open=False):
        enable_reasoning = gr.Checkbox(
            label="Enable Reasoning (/think mode)",
            value=True,
            info="Enable chain-of-thought reasoning for complex queries"
        )
        temperature = gr.Slider(
            minimum=0.0,
            maximum=2.0,
            value=0.6,
            step=0.1,
            label="Temperature",
            info="Controls randomness (recommended: 0.6 for reasoning, ignored for non-reasoning)"
        )
        top_p = gr.Slider(
            minimum=0.0,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p",
            info="Controls diversity (recommended: 0.95 for reasoning, ignored for non-reasoning)"
        )
        max_tokens = gr.Slider(
            minimum=32,
            maximum=2048,
            value=1024,
            step=32,
            label="Max New Tokens",
            info="Maximum number of tokens to generate (recommended: 1024+ for reasoning)"
        )
    def user_submit(message, history):
        return "", history + [[message, None]]

    def bot_response(history, enable_reasoning, temperature, top_p, max_tokens):
        if not history:
            return history

        message = history[-1][0]
        try:
            response = generate_response(
                message,
                history[:-1],
                enable_reasoning,
                temperature,
                top_p,
                max_tokens
            )
            history[-1][1] = response
        except Exception as e:
            history[-1][1] = f"Error generating response: {str(e)}"
        return history
    msg.submit(
        user_submit,
        [msg, chatbot],
        [msg, chatbot],
        queue=False
    ).then(
        bot_response,
        [chatbot, enable_reasoning, temperature, top_p, max_tokens],
        chatbot
    )

    submit.click(
        user_submit,
        [msg, chatbot],
        [msg, chatbot],
        queue=False
    ).then(
        bot_response,
        [chatbot, enable_reasoning, temperature, top_p, max_tokens],
        chatbot
    )

    clear.click(lambda: None, None, chatbot, queue=False)
    # Example prompts
    gr.Examples(
        examples=[
            "Write a haiku about GPUs",
            "Explain quantum computing in simple terms",
            "What is the capital of France?",
            "Solve this step by step: If a train travels 120 miles in 2 hours, what is its average speed?",
            "Write a short story about a robot learning to paint"
        ],
        inputs=msg
    )

if __name__ == "__main__":
    demo.launch()
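
For reference, a minimal requirements.txt to accompany this script might look like the following. The exact contents are an assumption: ZeroGPU Spaces typically preinstall gradio, torch, and spaces, the transformers pin mirrors the runtime install at the top of the script, and accelerate is required for device_map="auto".

transformers==4.48.3
accelerate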