Update app.py
app.py
CHANGED
@@ -3,16 +3,21 @@ from pydantic import BaseModel
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 
-# Initialize the Qwen model and tokenizer
+# Initialize the Qwen model and tokenizer (with trust_remote_code)
+model_id = "Qwen/Qwen-1_8B-Chat"
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+print(f"🚀 Loading model: {model_id} on {device}")
+
 tokenizer = AutoTokenizer.from_pretrained(
-
+    model_id,
     trust_remote_code=True
 )
 model = AutoModelForCausalLM.from_pretrained(
-
+    model_id,
     trust_remote_code=True,
     torch_dtype=torch.float32
-).to(
+).to(device)
 
 # Create the FastAPI app
 app = FastAPI()
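The request schema and the conversation store sit outside this diff, but the handler below depends on them. A minimal sketch of what they presumably look like, inferred only from the `class Prompt(BaseModel):` context in the next hunk header and from the names the handler reads (`prompt.text`, `prompt.reset`, `chat_history`); the actual definitions in app.py may differ:

from fastapi import FastAPI
from pydantic import BaseModel

class Prompt(BaseModel):
    text: str            # the user's message
    reset: bool = False  # when true, the handler clears the history

# global in-memory conversation state shared across requests
chat_history = []

Because chat_history is a single module-level list, every client shares the same conversation until one of them sends reset.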
@@ -26,8 +31,11 @@ class Prompt(BaseModel):
 @app.post("/chat")
 async def chat(prompt: Prompt):
     global chat_history
+
+    print(f"\n📝 User input: {prompt.text}")
     if prompt.reset:
         chat_history = []
+        print("🔄 Chat history reset")
 
     chat_history.append({"role": "user", "content": prompt.text})
 
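The loop that turns chat_history into the ChatML prompt falls between these hunks and is not shown. Judging from the `chatml += "<|im_start|>assistant\n"` line in the next hunk and the role/content dicts stored above, it presumably follows the ChatML layout Qwen chat models expect; a rough sketch, not the literal code in app.py:

# build the prompt from the stored turns in ChatML form
chatml = ""
for msg in chat_history:
    chatml += f"<|im_start|>{msg['role']}\n{msg['content']}<|im_end|>\n"
# the next hunk then appends "<|im_start|>assistant\n" so generation
# continues as the assistant turn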
@@ -38,7 +46,7 @@ async def chat(prompt: Prompt):
     chatml += "<|im_start|>assistant\n"
 
     try:
-        inputs = tokenizer(chatml, return_tensors="pt").to(
+        inputs = tokenizer(chatml, return_tensors="pt").to(device)
         outputs = model.generate(
             **inputs,
             max_new_tokens=512,
@@ -46,11 +54,24 @@ async def chat(prompt: Prompt):
             temperature=0.7,
             top_p=0.9
         )
-        response = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
+        response = tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
+
+        print("🧠 Raw model output:", response)
+
+        # Extract the assistant reply
+        if "<|im_start|>assistant\n" in response:
+            reply = response.split("<|im_end|>")[0].split("<|im_start|>assistant\n")[-1].strip()
+        else:
+            reply = response  # fallback
+
+        if not reply:
+            reply = "⚠️ The model produced no reply, please try again later."
+            print("⚠️ Reply was an empty string")
 
         chat_history.append({"role": "assistant", "content": reply})
+        print("✅ Final reply:", reply)
         return {"reply": reply}
+
     except Exception as e:
         print("❌ Model response error:", e)
         return {"reply": "Could not get a model reply right now, please try again later."}
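Because `tokenizer.decode(outputs[0], ...)` returns the prompt plus the completion, the new code isolates the assistant turn by splitting on the ChatML markers. A common alternative (not what this commit does) is to decode only the tokens generated after the prompt, which sidesteps the string surgery; a sketch using the same variable names:

# decode only what the model appended after the prompt
prompt_len = inputs["input_ids"].shape[-1]
reply = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()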
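For reference, a minimal client call against the /chat endpoint defined in this file. The field names (`text`, `reset`) and the `reply` key come from the handler above; the base URL is a placeholder for wherever the Space, or a local `uvicorn app:app` instance, is reachable, and the requests package is assumed to be installed:

import requests

BASE_URL = "http://localhost:8000"  # placeholder, replace with the actual server URL

# start a fresh conversation
r = requests.post(f"{BASE_URL}/chat", json={"text": "Hello!", "reset": True})
print(r.json()["reply"])

# follow-up turn that reuses the server-side chat_history
r = requests.post(f"{BASE_URL}/chat", json={"text": "Tell me more.", "reset": False})
print(r.json()["reply"])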