arad1367 commited on
Commit
8426a0f
·
verified ·
1 Parent(s): c214a24

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -35
app.py CHANGED
@@ -3,64 +3,69 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
3
  import torch
4
  import gradio as gr
5
 
6
- # Load model and tokenizer
7
  model_name = "Qwen/Qwen2.5-3B-Instruct"
8
 
 
9
  print("Loading tokenizer...")
10
  tokenizer = AutoTokenizer.from_pretrained(model_name)
11
 
 
12
  print("Loading model...")
13
  model = AutoModelForCausalLM.from_pretrained(
14
  model_name,
15
- torch_dtype=torch.bfloat16, # Use bfloat16 to save memory and speed up inference
16
- device_map="auto", # Automatically use GPU if available
17
- trust_remote_code=True # Required for Qwen models
 
 
18
  )
19
 
20
- # Define chat function
21
  def respond(message, history):
 
22
  messages = [{"role": "user", "content": message}]
23
-
24
- # Apply chat template
25
- text = tokenizer.apply_chat_template(
26
  messages,
27
  tokenize=False,
28
  add_generation_prompt=True
29
  )
30
-
31
- # Tokenize input
32
- model_inputs = tokenizer([text], return_tensors="pt").to(model.device)
33
-
34
- # Generate response
35
- generated_ids = model.generate(
36
- **model_inputs,
37
- max_new_tokens=512,
38
- do_sample=True,
39
- temperature=0.7,
40
- top_p=0.9,
41
- repetition_penalty=1.1
42
- )
43
-
44
- # Extract only the new tokens
45
- generated_ids = generated_ids[0][model_inputs.input_ids.shape[-1]:]
46
- response = tokenizer.decode(generated_ids, skip_special_tokens=True)
47
-
48
- return response
49
 
50
- # Create Gradio chat interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
  demo = gr.ChatInterface(
52
  fn=respond,
53
- title="Qwen2.5-3B Chatbot",
54
- description="Chat with Qwen2.5-3B-Instruct, a powerful 3-billion-parameter LLM by Alibaba Cloud.",
55
  examples=[
56
  "Explain quantum computing in simple terms.",
57
- "Write a Python function to calculate Fibonacci numbers.",
58
- "Tell me a joke about AI."
 
59
  ],
60
- retry_btn=None,
61
- undo_btn=None,
 
 
62
  )
63
 
64
- # Launch the app
65
  if __name__ == "__main__":
66
  demo.launch()
 
3
  import torch
4
  import gradio as gr
5
 
6
+ # Model ID
7
  model_name = "Qwen/Qwen2.5-3B-Instruct"
8
 
9
+ # Load tokenizer
10
  print("Loading tokenizer...")
11
  tokenizer = AutoTokenizer.from_pretrained(model_name)
12
 
13
+ # Load model with bfloat16 and device_map for efficient GPU usage
14
  print("Loading model...")
15
  model = AutoModelForCausalLM.from_pretrained(
16
  model_name,
17
+ torch_dtype=torch.bfloat16,
18
+ device_map="auto",
19
+ trust_remote_code=True,
20
+ # Optional: use 4-bit quantization to save VRAM
21
+ # quantization_config=transformers.BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
22
  )
23
 
24
+ # Chatbot function
25
  def respond(message, history):
26
+ # Format message with chat template
27
  messages = [{"role": "user", "content": message}]
28
+ prompt = tokenizer.apply_chat_template(
 
 
29
  messages,
30
  tokenize=False,
31
  add_generation_prompt=True
32
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
+ # Tokenize
35
+ inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
36
+
37
+ # Generate
38
+ with torch.no_grad():
39
+ outputs = model.generate(
40
+ **inputs,
41
+ max_new_tokens=512,
42
+ temperature=0.7,
43
+ top_p=0.9,
44
+ do_sample=True,
45
+ pad_token_id=tokenizer.eos_token_id
46
+ )
47
+
48
+ # Decode only the response part
49
+ full_response = tokenizer.decode(outputs[0][inputs["input_ids"].shape[-1]:], skip_special_tokens=True)
50
+ return full_response
51
+
52
+ # Create Gradio ChatInterface
53
  demo = gr.ChatInterface(
54
  fn=respond,
55
+ title="💬 Qwen2.5-3B-Instruct Chatbot",
56
+ description="A smart, open-source chatbot powered by Qwen2.5-3B-Instruct. Ask anything!",
57
  examples=[
58
  "Explain quantum computing in simple terms.",
59
+ "Write a Python function to check if a number is prime.",
60
+ "Solve: 3x + 5 = 17",
61
+ "Tell me a fun fact about space."
62
  ],
63
+ # ✅ These are now supported with updated Gradio
64
+ retry_btn=None, # Hides retry button
65
+ undo_btn=None, # Hides undo button
66
+ clear_btn=None # Optional: hide clear button too
67
  )
68
 
69
+ # Launch
70
  if __name__ == "__main__":
71
  demo.launch()