BasilTh
commited on
Commit
Β·
77b14f6
1
Parent(s):
bbe7b0d
Deploy updated SLM customer-support chatbot
Browse files- SLM_CService.py +144 -71
- app.py +12 -7
- requirements.txt +3 -2
SLM_CService.py
CHANGED
@@ -1,60 +1,121 @@
|
|
1 |
# βββ SLM_CService.py βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
2 |
-
|
3 |
-
# Fix for libgomp warning in Spaces
|
4 |
-
os.environ["OMP_NUM_THREADS"] = "1"
|
5 |
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
8 |
import torch
|
9 |
|
10 |
from transformers import AutoTokenizer, BitsAndBytesConfig, pipeline
|
11 |
from peft import PeftModel
|
|
|
|
|
|
|
12 |
|
13 |
-
#
|
14 |
-
#
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
#
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
26 |
tokenizer.truncation_side = "right"
|
27 |
|
28 |
-
#
|
|
|
|
|
29 |
bnb_cfg = BitsAndBytesConfig(
|
30 |
load_in_4bit=True,
|
31 |
bnb_4bit_quant_type="nf4",
|
32 |
bnb_4bit_use_double_quant=True,
|
33 |
-
bnb_4bit_compute_dtype=torch.bfloat16
|
34 |
)
|
35 |
-
|
36 |
model = unsloth.FastLanguageModel.from_pretrained(
|
37 |
-
|
38 |
load_in_4bit=True,
|
39 |
-
quantization_config=bnb_cfg,
|
40 |
device_map="auto",
|
41 |
-
trust_remote_code=True
|
42 |
)
|
43 |
-
# 5b) Attach your LoRA adapter
|
44 |
-
model = PeftModel.from_pretrained(model, "ThomasBasil/bitext-qlora-tinyllama")
|
45 |
|
46 |
-
|
|
|
|
|
|
|
47 |
chat_pipe = pipeline(
|
48 |
"text-generation",
|
49 |
model=model,
|
50 |
tokenizer=tokenizer,
|
51 |
trust_remote_code=True,
|
52 |
return_full_text=False,
|
53 |
-
generate_kwargs={"max_new_tokens":128, "do_sample":True, "top_p":0.9, "temperature":0.7}
|
54 |
)
|
55 |
|
56 |
-
#
|
57 |
-
|
|
|
|
|
|
|
58 |
order_re = re.compile(r"#(\d{1,10})")
|
59 |
def extract_order(text: str):
|
60 |
m = order_re.search(text)
|
@@ -65,70 +126,82 @@ def handle_eta(o): return f"Delivery for order #{o} typically takes 3β5 day
|
|
65 |
def handle_track(o): return f"Track order #{o} here: https://track.example.com/{o}"
|
66 |
def handle_link(o): return f"Hereβs the latest tracking link for order #{o}: https://track.example.com/{o}"
|
67 |
def handle_return_policy(_=None):
|
68 |
-
return ("Our return policy allows returns of unused items in their original packaging "
|
69 |
-
"
|
70 |
def handle_gratitude(_=None):
|
71 |
return "Youβre welcome! Is there anything else I can help with?"
|
72 |
def handle_escalation(_=None):
|
73 |
return "Iβm sorry, I donβt have that information. Would you like me to connect you with a human agent?"
|
74 |
|
75 |
-
# 8) Core chat fn
|
76 |
stored_order = None
|
77 |
pending_intent = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
def chat_with_memory(user_input: str) -> str:
|
79 |
global stored_order, pending_intent
|
80 |
|
81 |
-
|
82 |
-
|
83 |
|
84 |
-
#
|
85 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
86 |
if new_o:
|
87 |
stored_order = new_o
|
88 |
-
if pending_intent in ("status","eta","track","link"):
|
89 |
-
fn = {"status":handle_status,"eta":handle_eta,"track":handle_track,"link":handle_link}[pending_intent]
|
90 |
reply = fn(stored_order)
|
91 |
pending_intent = None
|
92 |
-
|
93 |
return reply
|
94 |
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
if "return" in ui:
|
105 |
-
reply = handle_return_policy()
|
106 |
-
conversation_history.append(("Assistant", reply))
|
107 |
-
return reply
|
108 |
-
|
109 |
-
# E) Classify intent
|
110 |
-
if any(k in ui for k in ["status","where is my order","check status"]):
|
111 |
-
intent="status"
|
112 |
-
elif any(k in ui for k in ["how long","eta","delivery time"]):
|
113 |
-
intent="eta"
|
114 |
-
elif any(k in ui for k in ["how can i track","track my order","where is my package"]):
|
115 |
-
intent="track"
|
116 |
-
elif "tracking link" in ui or "resend" in ui:
|
117 |
-
intent="link"
|
118 |
else:
|
119 |
-
intent="fallback"
|
120 |
|
121 |
-
#
|
122 |
-
if intent in ("status","eta","track","link"):
|
123 |
if not stored_order:
|
124 |
pending_intent = intent
|
125 |
reply = "Sureβwhatβs your order number (e.g., #12345)?"
|
126 |
else:
|
127 |
-
fn = {"status":handle_status,"eta":handle_eta,"track":handle_track,"link":handle_link}[intent]
|
128 |
reply = fn(stored_order)
|
129 |
-
|
130 |
-
reply
|
131 |
|
132 |
-
#
|
133 |
-
|
|
|
|
|
|
|
134 |
return reply
|
|
|
1 |
# βββ SLM_CService.py βββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
2 |
+
# Launch-time model setup + FSM + conversational memory for the chatbot.
|
|
|
|
|
3 |
|
4 |
+
import os, shutil, zipfile
|
5 |
+
os.environ["OMP_NUM_THREADS"] = "1" # quiet libgomp noise
|
6 |
+
os.environ.pop("HF_HUB_OFFLINE", None) # avoid accidental offline mode
|
7 |
+
|
8 |
+
# 1) Unsloth must be imported before transformers
|
9 |
+
import unsloth
|
10 |
import torch
|
11 |
|
12 |
from transformers import AutoTokenizer, BitsAndBytesConfig, pipeline
|
13 |
from peft import PeftModel
|
14 |
+
from langchain.memory import ConversationBufferMemory
|
15 |
+
import gdown
|
16 |
+
import re
|
17 |
|
18 |
+
# ββ Persistent storage (HF Spaces -> Settings -> Persistent storage) βββββββββ
|
19 |
+
# Docs: /data persists across Space restarts. hf docs: persistent storage. :contentReference[oaicite:0]{index=0}
|
20 |
+
PERSIST_DIR = os.environ.get("PERSIST_DIR", "/data/slm_assets")
|
21 |
+
ADAPTER_DIR = os.path.join(PERSIST_DIR, "adapter")
|
22 |
+
TOKENIZER_DIR = os.path.join(PERSIST_DIR, "tokenizer")
|
23 |
+
ZIP_PATH = os.path.join(PERSIST_DIR, "assets.zip")
|
24 |
+
|
25 |
+
# ββ Provide Google Drive IDs as Secrets (HF Space -> Settings -> Variables) ββ
|
26 |
+
# Either one zip with both folders...
|
27 |
+
GDRIVE_ZIP_ID = os.environ.get("GDRIVE_ZIP_ID")
|
28 |
+
# ...or separate zips/files for each:
|
29 |
+
GDRIVE_ADAPTER_ID = os.environ.get("GDRIVE_ADAPTER_ID")
|
30 |
+
GDRIVE_TOKENIZER_ID = os.environ.get("GDRIVE_TOKENIZER_ID")
|
31 |
+
|
32 |
+
def _ensure_dirs():
|
33 |
+
os.makedirs(PERSIST_DIR, exist_ok=True)
|
34 |
+
os.makedirs(ADAPTER_DIR, exist_ok=True)
|
35 |
+
os.makedirs(TOKENIZER_DIR, exist_ok=True)
|
36 |
+
|
37 |
+
def _have_local_assets():
|
38 |
+
# minimal sanity checks for typical PEFT/tokenizer files
|
39 |
+
tok_ok = any(os.path.exists(os.path.join(TOKENIZER_DIR, f))
|
40 |
+
for f in ("tokenizer.json", "tokenizer.model", "tokenizer_config.json"))
|
41 |
+
lora_ok = any(os.path.exists(os.path.join(ADAPTER_DIR, f))
|
42 |
+
for f in ("adapter_config.json", "adapter_model.bin", "adapter_model.safetensors"))
|
43 |
+
return tok_ok and lora_ok
|
44 |
+
|
45 |
+
def _download_from_drive():
|
46 |
+
"""Download adapter/tokenizer from Google Drive into /data using gdown."""
|
47 |
+
_ensure_dirs()
|
48 |
+
if GDRIVE_ZIP_ID:
|
49 |
+
gdown.download(id=GDRIVE_ZIP_ID, output=ZIP_PATH, quiet=False) # gdown is built for Drive. :contentReference[oaicite:1]{index=1}
|
50 |
+
with zipfile.ZipFile(ZIP_PATH, "r") as zf:
|
51 |
+
zf.extractall(PERSIST_DIR)
|
52 |
+
return
|
53 |
+
|
54 |
+
if GDRIVE_ADAPTER_ID:
|
55 |
+
ad_zip = os.path.join(PERSIST_DIR, "adapter.zip")
|
56 |
+
gdown.download(id=GDRIVE_ADAPTER_ID, output=ad_zip, quiet=False)
|
57 |
+
try:
|
58 |
+
with zipfile.ZipFile(ad_zip, "r") as zf:
|
59 |
+
zf.extractall(ADAPTER_DIR)
|
60 |
+
except zipfile.BadZipFile:
|
61 |
+
# not a zip β assume single file
|
62 |
+
shutil.move(ad_zip, os.path.join(ADAPTER_DIR, "adapter_model.bin"))
|
63 |
+
|
64 |
+
if GDRIVE_TOKENIZER_ID:
|
65 |
+
tk_zip = os.path.join(PERSIST_DIR, "tokenizer.zip")
|
66 |
+
gdown.download(id=GDRIVE_TOKENIZER_ID, output=tk_zip, quiet=False)
|
67 |
+
try:
|
68 |
+
with zipfile.ZipFile(tk_zip, "r") as zf:
|
69 |
+
zf.extractall(TOKENIZER_DIR)
|
70 |
+
except zipfile.BadZipFile:
|
71 |
+
shutil.move(tk_zip, os.path.join(TOKENIZER_DIR, "tokenizer.json"))
|
72 |
+
|
73 |
+
# ββ Ensure local assets from Drive (first launch will download) ββββββββββββββ
|
74 |
+
if not _have_local_assets():
|
75 |
+
_download_from_drive() # persists in /data if you enabled it. :contentReference[oaicite:2]{index=2}
|
76 |
+
|
77 |
+
# ββ Tokenizer (from your Drive-backed folder) ββββββββββββββββββββββββββββββββ
|
78 |
+
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR, use_fast=False)
|
79 |
+
tokenizer.pad_token_id = tokenizer.eos_token_id
|
80 |
+
tokenizer.padding_side = "left"
|
81 |
tokenizer.truncation_side = "right"
|
82 |
|
83 |
+
# ββ Base model (4-bit) via Unsloth + your PEFT adapter ββββββββββββββββββββββ
|
84 |
+
BASE = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
|
85 |
+
|
86 |
bnb_cfg = BitsAndBytesConfig(
|
87 |
load_in_4bit=True,
|
88 |
bnb_4bit_quant_type="nf4",
|
89 |
bnb_4bit_use_double_quant=True,
|
90 |
+
bnb_4bit_compute_dtype=torch.bfloat16,
|
91 |
)
|
92 |
+
|
93 |
model = unsloth.FastLanguageModel.from_pretrained(
|
94 |
+
BASE,
|
95 |
load_in_4bit=True,
|
96 |
+
quantization_config=bnb_cfg, # prefer quantization_config over quant_type
|
97 |
device_map="auto",
|
98 |
+
trust_remote_code=True,
|
99 |
)
|
|
|
|
|
100 |
|
101 |
+
model = PeftModel.from_pretrained(model, ADAPTER_DIR)
|
102 |
+
|
103 |
+
# ββ Text-generation pipeline (use generate_kwargs, not generation_kwargs) ββββ
|
104 |
+
# Transformers pipelines accept `generate_kwargs` to forward to .generate(). :contentReference[oaicite:3]{index=3}
|
105 |
chat_pipe = pipeline(
|
106 |
"text-generation",
|
107 |
model=model,
|
108 |
tokenizer=tokenizer,
|
109 |
trust_remote_code=True,
|
110 |
return_full_text=False,
|
111 |
+
generate_kwargs={"max_new_tokens": 128, "do_sample": True, "top_p": 0.9, "temperature": 0.7},
|
112 |
)
|
113 |
|
114 |
+
# ββ Conversational memory (LangChain) ββββββββββββββββββββββββββββββββββββββββ
|
115 |
+
# ConversationBufferMemory stores full turn-by-turn chat history. :contentReference[oaicite:4]{index=4}
|
116 |
+
memory = ConversationBufferMemory(return_messages=True)
|
117 |
+
|
118 |
+
# ββ FSM helpers (your original logic, kept intact) βββββββββββββββββββββββββββ
|
119 |
order_re = re.compile(r"#(\d{1,10})")
|
120 |
def extract_order(text: str):
|
121 |
m = order_re.search(text)
|
|
|
126 |
def handle_track(o): return f"Track order #{o} here: https://track.example.com/{o}"
|
127 |
def handle_link(o): return f"Hereβs the latest tracking link for order #{o}: https://track.example.com/{o}"
|
128 |
def handle_return_policy(_=None):
|
129 |
+
return ("Our return policy allows returns of unused items in their original packaging within 30 days of receipt. "
|
130 |
+
"Would you like me to connect you with a human agent?")
|
131 |
def handle_gratitude(_=None):
|
132 |
return "Youβre welcome! Is there anything else I can help with?"
|
133 |
def handle_escalation(_=None):
|
134 |
return "Iβm sorry, I donβt have that information. Would you like me to connect you with a human agent?"
|
135 |
|
|
|
136 |
stored_order = None
|
137 |
pending_intent = None
|
138 |
+
|
139 |
+
def _history_to_prompt(user_input: str) -> str:
|
140 |
+
"""Build a prompt from LangChain memory turns for fallback generation."""
|
141 |
+
hist = memory.load_memory_variables({}).get("chat_history", [])
|
142 |
+
prompt = "You are a helpful support assistant.\n"
|
143 |
+
for msg in hist:
|
144 |
+
# LangChain messages expose a .type like 'human'/'ai' in many versions
|
145 |
+
mtype = getattr(msg, "type", "")
|
146 |
+
role = "User" if mtype == "human" else "Assistant"
|
147 |
+
content = getattr(msg, "content", "")
|
148 |
+
prompt += f"{role}: {content}\n"
|
149 |
+
prompt += f"User: {user_input}\nAssistant: "
|
150 |
+
return prompt
|
151 |
+
|
152 |
def chat_with_memory(user_input: str) -> str:
|
153 |
global stored_order, pending_intent
|
154 |
|
155 |
+
ui = user_input.strip()
|
156 |
+
low = ui.lower()
|
157 |
|
158 |
+
# A) quick intent short-circuits
|
159 |
+
if any(tok in low for tok in ["thank you", "thanks", "thx"]):
|
160 |
+
reply = handle_gratitude()
|
161 |
+
memory.save_context({"input": ui}, {"output": reply})
|
162 |
+
return reply
|
163 |
+
if "return" in low:
|
164 |
+
reply = handle_return_policy()
|
165 |
+
memory.save_context({"input": ui}, {"output": reply})
|
166 |
+
return reply
|
167 |
+
|
168 |
+
# B) order number?
|
169 |
+
new_o = extract_order(ui)
|
170 |
if new_o:
|
171 |
stored_order = new_o
|
172 |
+
if pending_intent in ("status", "eta", "track", "link"):
|
173 |
+
fn = {"status": handle_status, "eta": handle_eta, "track": handle_track, "link": handle_link}[pending_intent]
|
174 |
reply = fn(stored_order)
|
175 |
pending_intent = None
|
176 |
+
memory.save_context({"input": ui}, {"output": reply})
|
177 |
return reply
|
178 |
|
179 |
+
# C) intent classification
|
180 |
+
if any(k in low for k in ["status", "where is my order", "check status"]):
|
181 |
+
intent = "status"
|
182 |
+
elif any(k in low for k in ["how long", "eta", "delivery time"]):
|
183 |
+
intent = "eta"
|
184 |
+
elif any(k in low for k in ["how can i track", "track my order", "where is my package"]):
|
185 |
+
intent = "track"
|
186 |
+
elif "tracking link" in low or "resend" in low:
|
187 |
+
intent = "link"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
188 |
else:
|
189 |
+
intent = "fallback"
|
190 |
|
191 |
+
# D) handle core intents (ask for order first if needed)
|
192 |
+
if intent in ("status", "eta", "track", "link"):
|
193 |
if not stored_order:
|
194 |
pending_intent = intent
|
195 |
reply = "Sureβwhatβs your order number (e.g., #12345)?"
|
196 |
else:
|
197 |
+
fn = {"status": handle_status, "eta": handle_eta, "track": handle_track, "link": handle_link}[intent]
|
198 |
reply = fn(stored_order)
|
199 |
+
memory.save_context({"input": ui}, {"output": reply})
|
200 |
+
return reply
|
201 |
|
202 |
+
# E) fallback β generate with chat history context
|
203 |
+
prompt = _history_to_prompt(ui)
|
204 |
+
out = chat_pipe(prompt)[0]["generated_text"]
|
205 |
+
reply = out.split("Assistant:")[-1].strip()
|
206 |
+
memory.save_context({"input": ui}, {"output": reply})
|
207 |
return reply
|
app.py
CHANGED
@@ -1,23 +1,28 @@
|
|
1 |
import os
|
2 |
-
os.environ["OMP_NUM_THREADS"] = "1" #
|
3 |
|
4 |
import gradio as gr
|
5 |
from SLM_CService import chat_with_memory
|
6 |
|
7 |
def respond(user_message, history):
|
|
|
|
|
8 |
bot_reply = chat_with_memory(user_message)
|
9 |
-
history = history + [(user_message, bot_reply)]
|
10 |
return history, history
|
11 |
|
12 |
with gr.Blocks() as demo:
|
13 |
gr.Markdown("# π Customer Support Chatbot")
|
14 |
-
chatbot = gr.Chatbot()
|
15 |
with gr.Row():
|
16 |
-
user_in = gr.Textbox(placeholder="Type your message here...")
|
17 |
-
|
18 |
-
reset
|
19 |
-
|
|
|
20 |
reset.click(lambda: ([], []), None, [chatbot, chatbot])
|
|
|
|
|
21 |
|
22 |
if __name__ == "__main__":
|
23 |
demo.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
1 |
import os
|
2 |
+
os.environ["OMP_NUM_THREADS"] = "1" # silence OpenMP spam
|
3 |
|
4 |
import gradio as gr
|
5 |
from SLM_CService import chat_with_memory
|
6 |
|
7 |
def respond(user_message, history):
|
8 |
+
if not user_message:
|
9 |
+
return history, history
|
10 |
bot_reply = chat_with_memory(user_message)
|
11 |
+
history = (history or []) + [(user_message, bot_reply)]
|
12 |
return history, history
|
13 |
|
14 |
with gr.Blocks() as demo:
|
15 |
gr.Markdown("# π Customer Support Chatbot")
|
16 |
+
chatbot = gr.Chatbot()
|
17 |
with gr.Row():
|
18 |
+
user_in = gr.Textbox(placeholder="Type your message here...", scale=5)
|
19 |
+
send = gr.Button("Send", variant="primary")
|
20 |
+
reset = gr.Button("π Reset Chat")
|
21 |
+
|
22 |
+
send.click(respond, [user_in, chatbot], [chatbot, chatbot])
|
23 |
reset.click(lambda: ([], []), None, [chatbot, chatbot])
|
24 |
+
# Optional: submit on enter
|
25 |
+
user_in.submit(respond, [user_in, chatbot], [chatbot, chatbot])
|
26 |
|
27 |
if __name__ == "__main__":
|
28 |
demo.launch(server_name="0.0.0.0", server_port=7860)
|
requirements.txt
CHANGED
@@ -1,11 +1,12 @@
|
|
1 |
-
gradio==5.41.1
|
2 |
transformers
|
3 |
torch
|
4 |
sentencepiece
|
5 |
-
langchain
|
6 |
bitsandbytes
|
7 |
peft
|
8 |
xformers
|
9 |
unsloth
|
10 |
unsloth_zoo
|
11 |
huggingface_hub
|
|
|
|
1 |
+
gradio==5.41.1
|
2 |
transformers
|
3 |
torch
|
4 |
sentencepiece
|
5 |
+
langchain
|
6 |
bitsandbytes
|
7 |
peft
|
8 |
xformers
|
9 |
unsloth
|
10 |
unsloth_zoo
|
11 |
huggingface_hub
|
12 |
+
gdown
|