# app.py
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import gradio as gr

# Model identifier
model_name = "Qwen/Qwen2.5-3B-Instruct"

# Load tokenizer and model
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)
print("Loading model...")
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)
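# Assumption: bfloat16 targets recent GPUs; on CPU-only or older hardware,
# torch_dtype=torch.float32 (or torch.float16) may be the safer choice.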

# Chat function (no history used for simplicity and compatibility)
def respond(message, history):
    messages = [{"role": "user", "content": message}]
    # Apply chat template
    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    # Tokenize
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
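    # Note: with device_map="auto", model.device typically resolves to the
    # device holding the first model shard, which is where inputs must live.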
    # Generate
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )
    # Decode only the newly generated tokens (slice off the prompt)
    response = tokenizer.decode(
        outputs[0][inputs['input_ids'].shape[-1]:],
        skip_special_tokens=True
    )
    return response
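
# Hypothetical local sanity check (left commented out so it never runs on import):
#   print(respond("What is 2 + 2?", []))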

# Gradio interface (retry_btn / undo_btn omitted to stay compatible across Gradio versions)
demo = gr.ChatInterface(
    fn=respond,
    title="Qwen2.5-3B-Instruct Chatbot",
    description="Ask me anything! I'm a 3B AI assistant by Alibaba Cloud.",
    examples=[
        "Explain quantum computing in simple terms.",
        "Write a Python function to check if a number is prime.",
        "Solve: 3x + 5 = 17",
    ],
)

# Launch
if __name__ == "__main__":
    demo.launch()
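    # Hypothetical alternative: demo.launch(share=True) creates a temporary
    # public URL, handy when running outside a hosted environment.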