BasilTh committed on
Commit
da2916f
·
1 Parent(s): ae5323d

Deploy updated SLM customer-support chatbot

Browse files
Files changed (2) hide show
  1. SLM_CService.py +71 -89
  2. app.py +7 -4
SLM_CService.py CHANGED
@@ -1,5 +1,5 @@
1
  # ── SLM_CService.py ───────────────────────────────────────────────────────────
2
- # Customer-support-only chatbot with strict NSFW blocking + domain guardrails.
3
 
4
  import os
5
  import re
@@ -16,26 +16,17 @@ from peft import PeftModel
16
  from langchain.memory import ConversationBufferMemory
17
 
18
  # ──────────────────────────────────────────────────────────────────────────────
19
- # Hub repos
20
- REPO = "ThomasBasil/bitext-qlora-tinyllama" # adapter + tokenizer at ROOT
21
- BASE = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" # base model
22
 
23
- # Generation params (pass at call time)
24
  GEN_KW = dict(
25
- max_new_tokens=160,
26
- do_sample=True,
27
- top_p=0.9,
28
- temperature=0.7,
29
- repetition_penalty=1.1,
30
- no_repeat_ngram_size=4,
31
  )
32
 
33
- # 4-bit NF4 (GPU needed)
34
  bnb_cfg = BitsAndBytesConfig(
35
- load_in_4bit=True,
36
- bnb_4bit_quant_type="nf4",
37
- bnb_4bit_use_double_quant=True,
38
- bnb_4bit_compute_dtype=torch.float16, # T4/A10G-friendly
39
  )
40
 
41
  # ---- Tokenizer & model -------------------------------------------------------
@@ -45,108 +36,79 @@ if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
45
  tokenizer.padding_side = "left"
46
  tokenizer.truncation_side = "right"
47
 
48
- # Unsloth returns (model, tokenizer) → unpack
49
  model, _ = unsloth.FastLanguageModel.from_pretrained(
50
- model_name=BASE,
51
- load_in_4bit=True,
52
- quantization_config=bnb_cfg,
53
- device_map="auto",
54
- trust_remote_code=True,
55
  )
56
  unsloth.FastLanguageModel.for_inference(model)
57
-
58
- # Apply your PEFT adapter from repo root
59
  model = PeftModel.from_pretrained(model, REPO)
60
  model.eval()
61
 
62
- # Text-generation pipeline (we pass gen params at call time)
63
  chat_pipe = pipeline(
64
- "text-generation",
65
- model=model,
66
- tokenizer=tokenizer,
67
- trust_remote_code=True,
68
- return_full_text=False,
69
  )
70
 
71
  # ──────────────────────────────────────────────────────────────────────────────
72
- # Moderation & blocking (strict)
73
  from transformers import TextClassificationPipeline
74
-
75
  SEXUAL_TERMS = [
76
  "sex","sexual","porn","nsfw","fetish","kink","bdsm","nude","naked","anal",
77
  "blowjob","handjob","cum","breast","boobs","vagina","penis","semen","ejaculate",
78
  "doggy","missionary","cowgirl","69","kamasutra","dominatrix","submissive","spank",
79
  "sex position","have sex","make love","how to flirt","dominant in bed",
80
  ]
81
-
82
- def _bad_words_ids(tokenizer, terms: List[str]) -> List[List[int]]:
83
- ids = set()
84
  for t in terms:
85
- for v in (t, " " + t):
86
- toks = tokenizer(v, add_special_tokens=False).input_ids
87
- if toks:
88
- ids.add(tuple(toks))
89
  return [list(t) for t in ids]
90
-
91
  BAD_WORD_IDS = _bad_words_ids(tokenizer, SEXUAL_TERMS)
92
 
93
  nsfw_cls: TextClassificationPipeline = pipeline(
94
- "text-classification",
95
- model="eliasalbouzidi/distilbert-nsfw-text-classifier",
96
- truncation=True,
97
  )
98
-
99
  toxicity_cls: TextClassificationPipeline = pipeline(
100
- "text-classification",
101
- model="unitary/toxic-bert",
102
- truncation=True,
103
- return_all_scores=True,
104
  )
105
-
106
  def is_sexual_or_toxic(text: str) -> bool:
107
  t = (text or "").lower()
108
- if any(k in t for k in SEXUAL_TERMS):
109
- return True
110
  try:
111
  res = nsfw_cls(t)[0]
112
- if (res.get("label","").lower() == "nsfw") and float(res.get("score",0)) > 0.60:
113
- return True
114
- except Exception:
115
- pass
116
  try:
117
  scores = toxicity_cls(t)[0]
118
- if any(item["score"] > 0.60 and item["label"].lower() in
119
- {"toxic","severe_toxic","obscene","threat","insult","identity_hate"} for item in scores):
120
  return True
121
- except Exception:
122
- pass
123
  return False
124
-
125
  REFUSAL = ("Sorry, I can’t help with that. I’m only for store support "
126
  "(orders, shipping, ETA, tracking, returns, warranty, account).")
127
 
128
  # ──────────────────────────────────────────────────────────────────────────────
129
- # Memory (kept simple)
130
- memory = ConversationBufferMemory(return_messages=True)
131
-
132
- # System prompt = domain guardrails
133
  SYSTEM_PROMPT = (
134
- "You are a customer-support assistant for our store. "
135
- "Only handle account, orders, shipping, delivery ETA, tracking links, returns/refunds, warranty, and store policy. "
136
  "If a request is out of scope or sexual/NSFW, refuse briefly and offer support options. "
137
  "Be concise and professional."
138
  )
139
-
140
- # Allowed support-ish keywords for routing
141
  ALLOWED_KEYWORDS = (
142
  "order","track","status","delivery","shipping","ship","eta","arrive",
143
  "refund","return","exchange","warranty","guarantee","account","billing",
144
  "address","cancel","policy","help","support","agent","human"
145
  )
146
 
147
- # FSM helpers
148
  order_re = re.compile(r"#(\d{1,10})")
149
- def extract_order(text: str): m = order_re.search(text); return m.group(1) if m else None
 
 
150
  def handle_status(o): return f"Order #{o} is in transit and should arrive in 3–5 business days."
151
  def handle_eta(o): return f"Delivery for order #{o} typically takes 3–5 days; you can track it at https://track.example.com/{o}"
152
  def handle_track(o): return f"Track order #{o} here: https://track.example.com/{o}"
@@ -154,16 +116,31 @@ def handle_link(o): return f"Here’s the latest tracking link for order #{o}:
154
  def handle_return_policy(_=None):
155
  return ("Our return policy allows returns of unused items in original packaging within 30 days of receipt. "
156
  "Would you like me to connect you with a human agent?")
 
 
 
157
  def handle_gratitude(_=None): return "You’re welcome! Anything else I can help with?"
158
  def handle_escalation(_=None): return "I can connect you with a human agent. Would you like me to do that?"
159
 
 
160
  stored_order = None
161
  pending_intent = None
162
 
 
 
 
 
 
 
 
 
 
 
163
  # ---- chat templating ---------------------------------------------------------
164
  def _lc_to_messages() -> List[Dict[str,str]]:
165
  msgs = [{"role": "system", "content": SYSTEM_PROMPT}]
166
- for m in memory.load_memory_variables({}).get("chat_history", []):
 
167
  role = "user" if getattr(m, "type", "") == "human" else "assistant"
168
  msgs.append({"role": role, "content": getattr(m, "content", "")})
169
  return msgs
@@ -187,7 +164,13 @@ def chat_with_memory(user_input: str) -> str:
187
  if not ui:
188
  return "How can I help with your order today?"
189
 
190
- # 1) Safety first
 
 
 
 
 
 
191
  if is_sexual_or_toxic(ui):
192
  reply = REFUSAL
193
  memory.save_context({"input": ui}, {"output": reply})
@@ -195,7 +178,7 @@ def chat_with_memory(user_input: str) -> str:
195
 
196
  low = ui.lower()
197
 
198
- # 2) Quick intents that don't depend on domain keywords
199
  if any(tok in low for tok in ["thank you","thanks","thx"]):
200
  reply = handle_gratitude()
201
  memory.save_context({"input": ui}, {"output": reply})
@@ -205,26 +188,24 @@ def chat_with_memory(user_input: str) -> str:
205
  memory.save_context({"input": ui}, {"output": reply})
206
  return reply
207
 
208
- # 3) *** ORDER NUMBER FIRST *** (so follow-ups like "It's #4567" work)
209
  new_o = extract_order(ui)
210
  if new_o:
211
  stored_order = new_o
212
- if pending_intent in ("status","eta","track","link"):
213
- fn = {"status": handle_status,"eta": handle_eta,"track": handle_track,"link": handle_link}[pending_intent]
214
- reply = fn(stored_order)
215
- pending_intent = None
216
- memory.save_context({"input": ui}, {"output": reply})
217
- return reply
218
- # No pending intent → fall through to classify what they want next.
219
 
220
- # 4) Support-only guard (but SKIP if we just saw an order number or have a pending intent)
221
  if pending_intent is None and new_o is None:
222
  if not any(k in low for k in ALLOWED_KEYWORDS) and not any(k in low for k in ("hi","hello","hey")):
223
  reply = "I’m for store support only (orders, shipping, returns, warranty, account). How can I help with those?"
224
  memory.save_context({"input": ui}, {"output": reply})
225
  return reply
226
 
227
- # 5) Intent classification
228
  if any(k in low for k in ["status","where is my order","check status"]):
229
  intent = "status"
230
  elif any(k in low for k in ["how long","eta","delivery time"]):
@@ -233,23 +214,24 @@ def chat_with_memory(user_input: str) -> str:
233
  intent = "track"
234
  elif "tracking link" in low or "resend" in low or "link" in low:
235
  intent = "link"
 
 
236
  else:
237
  intent = "fallback"
238
 
239
- # 6) Handle core intents
240
- if intent in ("status","eta","track","link"):
241
  if not stored_order:
242
  pending_intent = intent
243
- reply = "Sure—what’s your order number (e.g., #12345)?"
244
  else:
245
- fn = {"status": handle_status,"eta": handle_eta,"track": handle_track,"link": handle_link}[intent]
 
246
  reply = fn(stored_order)
247
  memory.save_context({"input": ui}, {"output": reply})
248
  return reply
249
 
250
- # 7) LLM fallback (still on-topic) + post-check for safety
251
  reply = _generate_reply(ui)
252
- if is_sexual_or_toxic(reply):
253
- reply = REFUSAL
254
  memory.save_context({"input": ui}, {"output": reply})
255
  return reply
 
1
  # ── SLM_CService.py ───────────────────────────────────────────────────────────
2
+ # Customer-support-only chatbot with strict NSFW blocking + proper Reset.
3
 
4
  import os
5
  import re
 
16
  from langchain.memory import ConversationBufferMemory
17
 
18
  # ──────────────────────────────────────────────────────────────────────────────
19
+ REPO = "ThomasBasil/bitext-qlora-tinyllama"
20
+ BASE = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
 
21
 
 
22
  GEN_KW = dict(
23
+ max_new_tokens=160, do_sample=True, top_p=0.9, temperature=0.7,
24
+ repetition_penalty=1.1, no_repeat_ngram_size=4,
 
 
 
 
25
  )
26
 
 
27
  bnb_cfg = BitsAndBytesConfig(
28
+ load_in_4bit=True, bnb_4bit_quant_type="nf4",
29
+ bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.float16,
 
 
30
  )
31
 
32
  # ---- Tokenizer & model -------------------------------------------------------
 
36
  tokenizer.padding_side = "left"
37
  tokenizer.truncation_side = "right"
38
 
 
39
  model, _ = unsloth.FastLanguageModel.from_pretrained(
40
+ model_name=BASE, load_in_4bit=True, quantization_config=bnb_cfg,
41
+ device_map="auto", trust_remote_code=True,
 
 
 
42
  )
43
  unsloth.FastLanguageModel.for_inference(model)
 
 
44
  model = PeftModel.from_pretrained(model, REPO)
45
  model.eval()
46
 
 
47
  chat_pipe = pipeline(
48
+ "text-generation", model=model, tokenizer=tokenizer,
49
+ trust_remote_code=True, return_full_text=False,
 
 
 
50
  )
51
 
52
  # ──────────────────────────────────────────────────────────────────────────────
53
+ # Moderation (unchanged from your last good version)
54
  from transformers import TextClassificationPipeline
 
55
  SEXUAL_TERMS = [
56
  "sex","sexual","porn","nsfw","fetish","kink","bdsm","nude","naked","anal",
57
  "blowjob","handjob","cum","breast","boobs","vagina","penis","semen","ejaculate",
58
  "doggy","missionary","cowgirl","69","kamasutra","dominatrix","submissive","spank",
59
  "sex position","have sex","make love","how to flirt","dominant in bed",
60
  ]
61
+ def _bad_words_ids(tok, terms: List[str]) -> List[List[int]]:
62
+ ids=set()
 
63
  for t in terms:
64
+ for v in (t, " "+t):
65
+ toks = tok(v, add_special_tokens=False).input_ids
66
+ if toks: ids.add(tuple(toks))
 
67
  return [list(t) for t in ids]
 
68
  BAD_WORD_IDS = _bad_words_ids(tokenizer, SEXUAL_TERMS)
69
 
70
  nsfw_cls: TextClassificationPipeline = pipeline(
71
+ "text-classification", model="eliasalbouzidi/distilbert-nsfw-text-classifier", truncation=True,
 
 
72
  )
 
73
  toxicity_cls: TextClassificationPipeline = pipeline(
74
+ "text-classification", model="unitary/toxic-bert", truncation=True, return_all_scores=True,
 
 
 
75
  )
 
76
  def is_sexual_or_toxic(text: str) -> bool:
77
  t = (text or "").lower()
78
+ if any(k in t for k in SEXUAL_TERMS): return True
 
79
  try:
80
  res = nsfw_cls(t)[0]
81
+ if (res.get("label","").lower()=="nsfw") and float(res.get("score",0))>0.60: return True
82
+ except Exception: pass
 
 
83
  try:
84
  scores = toxicity_cls(t)[0]
85
+ if any(s["score"]>0.60 and s["label"].lower() in
86
+ {"toxic","severe_toxic","obscene","threat","insult","identity_hate"} for s in scores):
87
  return True
88
+ except Exception: pass
 
89
  return False
 
90
  REFUSAL = ("Sorry, I can’t help with that. I’m only for store support "
91
  "(orders, shipping, ETA, tracking, returns, warranty, account).")
92
 
93
  # ──────────────────────────────────────────────────────────────────────────────
94
+ # Memory + globals
95
+ memory = ConversationBufferMemory(return_messages=True)  # has .clear()
 
 
96
  SYSTEM_PROMPT = (
97
+ "You are a customer-support assistant for our store. Only handle account, "
98
+ "orders, shipping, delivery ETA, tracking links, returns/refunds, warranty, and store policy. "
99
  "If a request is out of scope or sexual/NSFW, refuse briefly and offer support options. "
100
  "Be concise and professional."
101
  )
 
 
102
  ALLOWED_KEYWORDS = (
103
  "order","track","status","delivery","shipping","ship","eta","arrive",
104
  "refund","return","exchange","warranty","guarantee","account","billing",
105
  "address","cancel","policy","help","support","agent","human"
106
  )
107
 
 
108
  order_re = re.compile(r"#(\d{1,10})")
109
+ def extract_order(text: str):
110
+ m = order_re.search(text); return m.group(1) if m else None
111
+
112
  def handle_status(o): return f"Order #{o} is in transit and should arrive in 3–5 business days."
113
  def handle_eta(o): return f"Delivery for order #{o} typically takes 3–5 days; you can track it at https://track.example.com/{o}"
114
  def handle_track(o): return f"Track order #{o} here: https://track.example.com/{o}"
 
116
  def handle_return_policy(_=None):
117
  return ("Our return policy allows returns of unused items in original packaging within 30 days of receipt. "
118
  "Would you like me to connect you with a human agent?")
119
+ def handle_cancel(o=None):
120
+ return (f"I’ve submitted a cancellation request for order #{o}. If it has already shipped, "
121
+ "we’ll process a return/refund once it’s back. You’ll receive a confirmation email shortly.")
122
  def handle_gratitude(_=None): return "You’re welcome! Anything else I can help with?"
123
  def handle_escalation(_=None): return "I can connect you with a human agent. Would you like me to do that?"
124
 
125
+ # >>> state that must reset <<<
126
  stored_order = None
127
  pending_intent = None
128
 
129
+ # public reset hook (called from app.py)
130
+ def reset_state():
131
+ global stored_order, pending_intent
132
+ stored_order = None
133
+ pending_intent = None
134
+ # clear conversation buffer (official API)
135
+ try: memory.clear()
136
+ except Exception: pass
137
+ return True
138
+
139
  # ---- chat templating ---------------------------------------------------------
140
  def _lc_to_messages() -> List[Dict[str,str]]:
141
  msgs = [{"role": "system", "content": SYSTEM_PROMPT}]
142
+ hist = memory.load_memory_variables({}).get("chat_history", []) or []
143
+ for m in hist:
144
  role = "user" if getattr(m, "type", "") == "human" else "assistant"
145
  msgs.append({"role": role, "content": getattr(m, "content", "")})
146
  return msgs
 
164
  if not ui:
165
  return "How can I help with your order today?"
166
 
167
+ # If memory is empty, start clean (fresh session)
168
+ hist = memory.load_memory_variables({}).get("chat_history", []) or []
169
+ if len(hist) == 0:
170
+ stored_order = None
171
+ pending_intent = None
172
+
173
+ # 1) Safety
174
  if is_sexual_or_toxic(ui):
175
  reply = REFUSAL
176
  memory.save_context({"input": ui}, {"output": reply})
 
178
 
179
  low = ui.lower()
180
 
181
+ # 2) Quick intents
182
  if any(tok in low for tok in ["thank you","thanks","thx"]):
183
  reply = handle_gratitude()
184
  memory.save_context({"input": ui}, {"output": reply})
 
188
  memory.save_context({"input": ui}, {"output": reply})
189
  return reply
190
 
191
+ # 3) Order number FIRST
192
  new_o = extract_order(ui)
193
  if new_o:
194
  stored_order = new_o
195
+ if pending_intent in ("status","eta","track","link","cancel"):
196
+ fn = {"status": handle_status,"eta": handle_eta,"track": handle_track,
197
+ "link": handle_link,"cancel": handle_cancel}[pending_intent]
198
+ reply = fn(stored_order); pending_intent = None
199
+ memory.save_context({"input": ui}, {"output": reply}); return reply
 
 
200
 
201
+ # 4) Support-only guard (skip if pending intent or new order number)
202
  if pending_intent is None and new_o is None:
203
  if not any(k in low for k in ALLOWED_KEYWORDS) and not any(k in low for k in ("hi","hello","hey")):
204
  reply = "I’m for store support only (orders, shipping, returns, warranty, account). How can I help with those?"
205
  memory.save_context({"input": ui}, {"output": reply})
206
  return reply
207
 
208
+ # 5) Intents (added 'cancel')
209
  if any(k in low for k in ["status","where is my order","check status"]):
210
  intent = "status"
211
  elif any(k in low for k in ["how long","eta","delivery time"]):
 
214
  intent = "track"
215
  elif "tracking link" in low or "resend" in low or "link" in low:
216
  intent = "link"
217
+ elif "cancel" in low:
218
+ intent = "cancel"
219
  else:
220
  intent = "fallback"
221
 
222
+ if intent in ("status","eta","track","link","cancel"):
 
223
  if not stored_order:
224
  pending_intent = intent
225
  reply = "Sure—what’s your order number (e.g., #12345)?"
226
  else:
227
+ fn = {"status": handle_status,"eta": handle_eta,"track": handle_track,
228
+ "link": handle_link,"cancel": handle_cancel}[intent]
229
  reply = fn(stored_order)
230
  memory.save_context({"input": ui}, {"output": reply})
231
  return reply
232
 
233
+ # 6) LLM fallback (on-topic) + post-check
234
  reply = _generate_reply(ui)
235
+ if is_sexual_or_toxic(reply): reply = REFUSAL
 
236
  memory.save_context({"input": ui}, {"output": reply})
237
  return reply
app.py CHANGED
@@ -2,27 +2,30 @@ import os
2
  os.environ["OMP_NUM_THREADS"] = "1"
3
 
4
  import gradio as gr
5
- from SLM_CService import chat_with_memory
6
 
7
  def respond(user_message, history):
8
  reply = chat_with_memory(user_message or "")
9
- # messages format: list of dicts with role/content
10
  history = (history or []) + [
11
  {"role":"user","content": user_message or ""},
12
  {"role":"assistant","content": reply},
13
  ]
14
  return history
15
 
 
 
 
 
16
  with gr.Blocks() as demo:
17
  gr.Markdown("# 🛎 Customer Support Chatbot")
18
- chatbot = gr.Chatbot(type="messages") # Gradio 'messages' format.
19
  with gr.Row():
20
  user_in = gr.Textbox(placeholder="Ask about orders, tracking, returns…", scale=5)
21
  send = gr.Button("Send", variant="primary")
22
  reset = gr.Button("🔄 Reset Chat")
23
  send.click(respond, [user_in, chatbot], [chatbot])
24
  user_in.submit(respond, [user_in, chatbot], [chatbot])
25
- reset.click(lambda: [], None, [chatbot])
26
 
27
  if __name__ == "__main__":
28
  demo.launch(server_name="0.0.0.0", server_port=7860)
 
2
  os.environ["OMP_NUM_THREADS"] = "1"
3
 
4
  import gradio as gr
5
+ from SLM_CService import chat_with_memory, reset_state
6
 
7
  def respond(user_message, history):
8
  reply = chat_with_memory(user_message or "")
 
9
  history = (history or []) + [
10
  {"role":"user","content": user_message or ""},
11
  {"role":"assistant","content": reply},
12
  ]
13
  return history
14
 
15
+ def reset_chat():
16
+ reset_state() # <-- clear memory + globals
17
+ return [] # clear UI
18
+
19
  with gr.Blocks() as demo:
20
  gr.Markdown("# 🛎 Customer Support Chatbot")
21
+ chatbot = gr.Chatbot(type="messages")
22
  with gr.Row():
23
  user_in = gr.Textbox(placeholder="Ask about orders, tracking, returns…", scale=5)
24
  send = gr.Button("Send", variant="primary")
25
  reset = gr.Button("🔄 Reset Chat")
26
  send.click(respond, [user_in, chatbot], [chatbot])
27
  user_in.submit(respond, [user_in, chatbot], [chatbot])
28
+ reset.click(reset_chat, None, [chatbot]) # <-- real reset
29
 
30
  if __name__ == "__main__":
31
  demo.launch(server_name="0.0.0.0", server_port=7860)