File size: 1,576 Bytes
076ae66
10f8cb7
986c17a
076ae66
10f8cb7
28fc680
076ae66
 
cf8cf08
076ae66
 
 
 
 
 
8426a0f
 
 
076ae66
 
28fc680
076ae66
 
8426a0f
076ae66
 
 
986c17a
8426a0f
 
 
 
 
 
 
 
 
cf8cf08
8426a0f
cf8cf08
28fc680
cf8cf08
 
 
8426a0f
28fc680
 
076ae66
 
28fc680
 
076ae66
28fc680
 
 
 
076ae66
 
8426a0f
10f8cb7
986c17a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
# app.py
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import gradio as gr

# Hugging Face Hub id of the checkpoint this app serves.
model_name = "Qwen/Qwen2.5-3B-Instruct"

# Load the tokenizer and model once at import time; `respond` reads both as
# module-level globals on every request.
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(model_name)

print("Loading model...")
# `trust_remote_code` is intentionally left at its default (False): enabling
# it allows from_pretrained to execute Python files shipped in the Hub repo,
# and Qwen2.5 checkpoints load with the Qwen2 classes built into transformers,
# so the flag would add risk without adding functionality.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,  # half-size weights compared with float32
    device_map="auto",  # let the loader decide device placement
)

# Chat function
def respond(message, history):
    """Generate the assistant's reply to ``message`` given the prior turns.

    Args:
        message: The latest user message as text.
        history: Earlier turns of the conversation as supplied by the chat
            UI. Two layouts are accepted: a sequence of
            ``(user_text, assistant_text)`` pairs, or a sequence of
            ``{"role": ..., "content": ...}`` dicts. ``None`` or an empty
            sequence means the conversation is just starting.

    Returns:
        The decoded text of the newly generated tokens only (the prompt is
        stripped and special tokens are skipped).
    """
    # Rebuild the conversation so the model sees earlier turns; without this
    # every message would be answered with no context.
    messages = []
    for turn in history or []:
        if isinstance(turn, dict):
            role = turn.get("role")
            content = turn.get("content")
            if role in ("system", "user", "assistant") and isinstance(content, str) and content:
                messages.append({"role": role, "content": content})
        else:
            # Pair layout: first element is the user text, second the reply.
            # Non-text or missing entries (e.g. None) are skipped.
            for role, content in zip(("user", "assistant"), turn):
                if isinstance(content, str) and content:
                    messages.append({"role": role, "content": content})
    messages.append({"role": "user", "content": message})

    prompt = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    # Gradients are never needed for generation.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    # generate() returns prompt + completion; keep only the completion.
    return tokenizer.decode(
        outputs[0][prompt_length:],
        skip_special_tokens=True,
    )

# Chat UI: a Gradio ChatInterface that routes each submitted message to
# `respond`. (Original note: Gradio 3.50.2 supports ChatInterface fully.)
demo = gr.ChatInterface(
    respond,
    examples=[
        "Explain relativity in simple terms.",
        "Write a Python function to reverse a string.",
        "Solve: 2x + 8 = 20",
    ],
    description="Ask me anything! I'm a smart AI assistant by Alibaba Cloud.",
    title="Qwen2.5-3B Chatbot",
)

# Start the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    demo.launch()