BasilTh committed on
Commit 77b14f6 · 1 Parent(s): bbe7b0d

Deploy updated SLM customer-support chatbot

Files changed (3):
  1. SLM_CService.py  +144 -71
  2. app.py           +12 -7
  3. requirements.txt +3 -2
SLM_CService.py CHANGED
@@ -1,60 +1,121 @@
  # ─── SLM_CService.py ─────────────────────────────────────────────────────────
- import os
- # Fix for libgomp warning in Spaces
- os.environ["OMP_NUM_THREADS"] = "1"

- # 1) Unsloth must come first
- import unsloth
  import torch

  from transformers import AutoTokenizer, BitsAndBytesConfig, pipeline
  from peft import PeftModel

- # 2) Simple in-memory convo buffer
- #    we keep alternating (user, assistant) tuples
- conversation_history = []
-
- # 3) Model + adapter path in your repo (copied into the Space repo root)
- MODEL_DIR = "ThomasBasil/bitext-qlora-tinyllama"
-
- # 4) Load tokenizer from local dir
- tokenizer = AutoTokenizer.from_pretrained(
-     "ThomasBasil/bitext-qlora-tinyllama", use_fast=False
- )
- tokenizer.pad_token_id = tokenizer.eos_token_id
- tokenizer.padding_side = "left"
  tokenizer.truncation_side = "right"

- # 5) QLoRA + Unsloth load in 4-bit
  bnb_cfg = BitsAndBytesConfig(
      load_in_4bit=True,
      bnb_4bit_quant_type="nf4",
      bnb_4bit_use_double_quant=True,
-     bnb_4bit_compute_dtype=torch.bfloat16
  )
- # 5a) Base model
  model = unsloth.FastLanguageModel.from_pretrained(
-     "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
      load_in_4bit=True,
-     quantization_config=bnb_cfg,
      device_map="auto",
-     trust_remote_code=True
  )
- # 5b) Attach your LoRA adapter
- model = PeftModel.from_pretrained(model, "ThomasBasil/bitext-qlora-tinyllama")

- # 6) HF text-gen pipeline
  chat_pipe = pipeline(
      "text-generation",
      model=model,
      tokenizer=tokenizer,
      trust_remote_code=True,
      return_full_text=False,
-     generate_kwargs={"max_new_tokens":128, "do_sample":True, "top_p":0.9, "temperature":0.7}
  )

- # 7) FSM helpers (your existing code unmodified)
- import re
  order_re = re.compile(r"#(\d{1,10})")
  def extract_order(text: str):
      m = order_re.search(text)
@@ -65,70 +126,82 @@ def handle_eta(o): return f"Delivery for order #{o} typically takes 3–5 day
  def handle_track(o): return f"Track order #{o} here: https://track.example.com/{o}"
  def handle_link(o): return f"Here’s the latest tracking link for order #{o}: https://track.example.com/{o}"
  def handle_return_policy(_=None):
-     return ("Our return policy allows returns of unused items in their original packaging "
-             "within 30 days of receipt. Would you like me to connect you with a human agent?")
  def handle_gratitude(_=None):
      return "You’re welcome! Is there anything else I can help with?"
  def handle_escalation(_=None):
      return "I’m sorry, I don’t have that information. Would you like me to connect you with a human agent?"

- # 8) Core chat fn
  stored_order = None
  pending_intent = None
  def chat_with_memory(user_input: str) -> str:
      global stored_order, pending_intent

-     # A) Save into history
-     conversation_history.append(("User", user_input))

-     # B) New order?
-     new_o = extract_order(user_input)
      if new_o:
          stored_order = new_o
-         if pending_intent in ("status","eta","track","link"):
-             fn = {"status":handle_status,"eta":handle_eta,"track":handle_track,"link":handle_link}[pending_intent]
              reply = fn(stored_order)
              pending_intent = None
-             conversation_history.append(("Assistant", reply))
              return reply

-     ui = user_input.lower().strip()
-
-     # C) Gratitude
-     if any(tok in ui for tok in ["thank you","thanks","thx"]):
-         reply = handle_gratitude()
-         conversation_history.append(("Assistant", reply))
-         return reply
-
-     # D) Return policy
-     if "return" in ui:
-         reply = handle_return_policy()
-         conversation_history.append(("Assistant", reply))
-         return reply
-
-     # E) Classify intent
-     if any(k in ui for k in ["status","where is my order","check status"]):
-         intent="status"
-     elif any(k in ui for k in ["how long","eta","delivery time"]):
-         intent="eta"
-     elif any(k in ui for k in ["how can i track","track my order","where is my package"]):
-         intent="track"
-     elif "tracking link" in ui or "resend" in ui:
-         intent="link"
      else:
-         intent="fallback"

-     # F) Fulfill or ask order #
-     if intent in ("status","eta","track","link"):
          if not stored_order:
              pending_intent = intent
              reply = "Sure—what’s your order number (e.g., #12345)?"
          else:
-             fn = {"status":handle_status,"eta":handle_eta,"track":handle_track,"link":handle_link}[intent]
              reply = fn(stored_order)
-     else:
-         reply = handle_escalation()

-     # G) Save & done
-     conversation_history.append(("Assistant", reply))
      return reply
 
  # ─── SLM_CService.py ─────────────────────────────────────────────────────────
+ # Launch-time model setup + FSM + conversational memory for the chatbot.

+ import os, shutil, zipfile
+ os.environ["OMP_NUM_THREADS"] = "1"     # quiet libgomp noise
+ os.environ.pop("HF_HUB_OFFLINE", None)  # avoid accidental offline mode
+
+ # 1) Unsloth must be imported before transformers
+ import unsloth
  import torch

  from transformers import AutoTokenizer, BitsAndBytesConfig, pipeline
  from peft import PeftModel
+ from langchain.memory import ConversationBufferMemory
+ import gdown
+ import re

+ # ── Persistent storage (HF Spaces -> Settings -> Persistent storage) ─────────
+ # Docs: /data persists across Space restarts (see the HF persistent-storage docs).
+ PERSIST_DIR = os.environ.get("PERSIST_DIR", "/data/slm_assets")
+ ADAPTER_DIR = os.path.join(PERSIST_DIR, "adapter")
+ TOKENIZER_DIR = os.path.join(PERSIST_DIR, "tokenizer")
+ ZIP_PATH = os.path.join(PERSIST_DIR, "assets.zip")
+
+ # ── Provide Google Drive IDs as Secrets (HF Space -> Settings -> Variables) ──
+ # Either one zip with both folders...
+ GDRIVE_ZIP_ID = os.environ.get("GDRIVE_ZIP_ID")
+ # ...or separate zips/files for each:
+ GDRIVE_ADAPTER_ID = os.environ.get("GDRIVE_ADAPTER_ID")
+ GDRIVE_TOKENIZER_ID = os.environ.get("GDRIVE_TOKENIZER_ID")
+
+ def _ensure_dirs():
+     os.makedirs(PERSIST_DIR, exist_ok=True)
+     os.makedirs(ADAPTER_DIR, exist_ok=True)
+     os.makedirs(TOKENIZER_DIR, exist_ok=True)
+
+ def _have_local_assets():
+     # minimal sanity checks for typical PEFT/tokenizer files
+     tok_ok = any(os.path.exists(os.path.join(TOKENIZER_DIR, f))
+                  for f in ("tokenizer.json", "tokenizer.model", "tokenizer_config.json"))
+     lora_ok = any(os.path.exists(os.path.join(ADAPTER_DIR, f))
+                   for f in ("adapter_config.json", "adapter_model.bin", "adapter_model.safetensors"))
+     return tok_ok and lora_ok
+
+ def _download_from_drive():
+     """Download adapter/tokenizer from Google Drive into /data using gdown."""
+     _ensure_dirs()
+     if GDRIVE_ZIP_ID:
+         gdown.download(id=GDRIVE_ZIP_ID, output=ZIP_PATH, quiet=False)  # gdown is built for Drive
+         with zipfile.ZipFile(ZIP_PATH, "r") as zf:
+             zf.extractall(PERSIST_DIR)
+         return
+
+     if GDRIVE_ADAPTER_ID:
+         ad_zip = os.path.join(PERSIST_DIR, "adapter.zip")
+         gdown.download(id=GDRIVE_ADAPTER_ID, output=ad_zip, quiet=False)
+         try:
+             with zipfile.ZipFile(ad_zip, "r") as zf:
+                 zf.extractall(ADAPTER_DIR)
+         except zipfile.BadZipFile:
+             # not a zip – assume single file
+             shutil.move(ad_zip, os.path.join(ADAPTER_DIR, "adapter_model.bin"))
+
+     if GDRIVE_TOKENIZER_ID:
+         tk_zip = os.path.join(PERSIST_DIR, "tokenizer.zip")
+         gdown.download(id=GDRIVE_TOKENIZER_ID, output=tk_zip, quiet=False)
+         try:
+             with zipfile.ZipFile(tk_zip, "r") as zf:
+                 zf.extractall(TOKENIZER_DIR)
+         except zipfile.BadZipFile:
+             shutil.move(tk_zip, os.path.join(TOKENIZER_DIR, "tokenizer.json"))
+
+ # ── Ensure local assets from Drive (first launch will download) ──────────────
+ if not _have_local_assets():
+     _download_from_drive()  # persists in /data if persistent storage is enabled
+
+ # ── Tokenizer (from your Drive-backed folder) ────────────────────────────────
+ tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR, use_fast=False)
+ tokenizer.pad_token_id = tokenizer.eos_token_id
+ tokenizer.padding_side = "left"
  tokenizer.truncation_side = "right"

+ # ── Base model (4-bit) via Unsloth + your PEFT adapter ──────────────────────
+ BASE = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
+
  bnb_cfg = BitsAndBytesConfig(
      load_in_4bit=True,
      bnb_4bit_quant_type="nf4",
      bnb_4bit_use_double_quant=True,
+     bnb_4bit_compute_dtype=torch.bfloat16,
  )
+
  model = unsloth.FastLanguageModel.from_pretrained(
+     BASE,
      load_in_4bit=True,
+     quantization_config=bnb_cfg,  # prefer quantization_config over quant_type
      device_map="auto",
+     trust_remote_code=True,
  )

+ model = PeftModel.from_pretrained(model, ADAPTER_DIR)
+
+ # ── Text-generation pipeline (use generate_kwargs, not generation_kwargs) ────
+ # Transformers pipelines accept `generate_kwargs` to forward to .generate().
  chat_pipe = pipeline(
      "text-generation",
      model=model,
      tokenizer=tokenizer,
      trust_remote_code=True,
      return_full_text=False,
+     generate_kwargs={"max_new_tokens": 128, "do_sample": True, "top_p": 0.9, "temperature": 0.7},
  )

+ # ── Conversational memory (LangChain) ────────────────────────────────────────
+ # ConversationBufferMemory stores full turn-by-turn chat history.
+ memory = ConversationBufferMemory(return_messages=True)
+
+ # ── FSM helpers (your original logic, kept intact) ───────────────────────────
  order_re = re.compile(r"#(\d{1,10})")
  def extract_order(text: str):
      m = order_re.search(text)

  def handle_track(o): return f"Track order #{o} here: https://track.example.com/{o}"
  def handle_link(o): return f"Here’s the latest tracking link for order #{o}: https://track.example.com/{o}"
  def handle_return_policy(_=None):
+     return ("Our return policy allows returns of unused items in their original packaging within 30 days of receipt. "
+             "Would you like me to connect you with a human agent?")
  def handle_gratitude(_=None):
      return "You’re welcome! Is there anything else I can help with?"
  def handle_escalation(_=None):
      return "I’m sorry, I don’t have that information. Would you like me to connect you with a human agent?"

  stored_order = None
  pending_intent = None
+
+ def _history_to_prompt(user_input: str) -> str:
+     """Build a prompt from LangChain memory turns for fallback generation."""
+     hist = memory.load_memory_variables({}).get("history", [])  # default ConversationBufferMemory key is "history"
+     prompt = "You are a helpful support assistant.\n"
+     for msg in hist:
+         # LangChain messages expose a .type like 'human'/'ai' in many versions
+         mtype = getattr(msg, "type", "")
+         role = "User" if mtype == "human" else "Assistant"
+         content = getattr(msg, "content", "")
+         prompt += f"{role}: {content}\n"
+     prompt += f"User: {user_input}\nAssistant: "
+     return prompt
+
  def chat_with_memory(user_input: str) -> str:
      global stored_order, pending_intent

+     ui = user_input.strip()
+     low = ui.lower()

+     # A) quick intent short-circuits
+     if any(tok in low for tok in ["thank you", "thanks", "thx"]):
+         reply = handle_gratitude()
+         memory.save_context({"input": ui}, {"output": reply})
+         return reply
+     if "return" in low:
+         reply = handle_return_policy()
+         memory.save_context({"input": ui}, {"output": reply})
+         return reply
+
+     # B) order number?
+     new_o = extract_order(ui)
      if new_o:
          stored_order = new_o
+         if pending_intent in ("status", "eta", "track", "link"):
+             fn = {"status": handle_status, "eta": handle_eta, "track": handle_track, "link": handle_link}[pending_intent]
              reply = fn(stored_order)
              pending_intent = None
+             memory.save_context({"input": ui}, {"output": reply})
              return reply

+     # C) intent classification
+     if any(k in low for k in ["status", "where is my order", "check status"]):
+         intent = "status"
+     elif any(k in low for k in ["how long", "eta", "delivery time"]):
+         intent = "eta"
+     elif any(k in low for k in ["how can i track", "track my order", "where is my package"]):
+         intent = "track"
+     elif "tracking link" in low or "resend" in low:
+         intent = "link"
      else:
+         intent = "fallback"

+     # D) handle core intents (ask for order first if needed)
+     if intent in ("status", "eta", "track", "link"):
          if not stored_order:
              pending_intent = intent
              reply = "Sure—what’s your order number (e.g., #12345)?"
          else:
+             fn = {"status": handle_status, "eta": handle_eta, "track": handle_track, "link": handle_link}[intent]
              reply = fn(stored_order)
+         memory.save_context({"input": ui}, {"output": reply})
+         return reply

+     # E) fallback → generate with chat history context
+     prompt = _history_to_prompt(ui)
+     out = chat_pipe(prompt)[0]["generated_text"]
+     reply = out.split("Assistant:")[-1].strip()
+     memory.save_context({"input": ui}, {"output": reply})
      return reply
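
Note: a minimal way to exercise the FSM path above locally is sketched below. It assumes the model assets are reachable, since importing SLM_CService loads the tokenizer and model at import time; the file name and prompts are hypothetical.

# smoke_test.py - hypothetical local check, not part of this commit
from SLM_CService import chat_with_memory

print(chat_with_memory("Where is my order?"))  # no order stored yet -> bot asks for an order number
print(chat_with_memory("It's #12345"))         # fulfills the pending "status" intent via handle_status
print(chat_with_memory("thanks"))              # gratitude short-circuit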
app.py CHANGED
@@ -1,23 +1,28 @@
  import os
- os.environ["OMP_NUM_THREADS"] = "1"  # Silence Gradio startup warning

  import gradio as gr
  from SLM_CService import chat_with_memory

  def respond(user_message, history):
      bot_reply = chat_with_memory(user_message)
-     history = history + [(user_message, bot_reply)]
      return history, history

  with gr.Blocks() as demo:
      gr.Markdown("# 🛎 Customer Support Chatbot")
-     chatbot = gr.Chatbot()  # Replaces ChatInterface/FileMessage/TextMessage
      with gr.Row():
-         user_in = gr.Textbox(placeholder="Type your message here...")
-         submit = gr.Button("Send")
-         reset = gr.Button("🔄 Reset Chat")
-     submit.click(respond, [user_in, chatbot], [chatbot, chatbot])
      reset.click(lambda: ([], []), None, [chatbot, chatbot])

  if __name__ == "__main__":
      demo.launch(server_name="0.0.0.0", server_port=7860)
 
  import os
+ os.environ["OMP_NUM_THREADS"] = "1"  # silence OpenMP spam

  import gradio as gr
  from SLM_CService import chat_with_memory

  def respond(user_message, history):
+     if not user_message:
+         return history, history
      bot_reply = chat_with_memory(user_message)
+     history = (history or []) + [(user_message, bot_reply)]
      return history, history

  with gr.Blocks() as demo:
      gr.Markdown("# 🛎 Customer Support Chatbot")
+     chatbot = gr.Chatbot()
      with gr.Row():
+         user_in = gr.Textbox(placeholder="Type your message here...", scale=5)
+         send = gr.Button("Send", variant="primary")
+         reset = gr.Button("🔄 Reset Chat")
+
+     send.click(respond, [user_in, chatbot], [chatbot, chatbot])
      reset.click(lambda: ([], []), None, [chatbot, chatbot])
+     # Optional: submit on enter
+     user_in.submit(respond, [user_in, chatbot], [chatbot, chatbot])

  if __name__ == "__main__":
      demo.launch(server_name="0.0.0.0", server_port=7860)
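
Note: respond() keeps Gradio's tuple-style chat history, which gr.Chatbot() still accepts in the pinned 5.x release; newer Gradio prefers the "messages" format. If the Space is later switched over, the handler might look roughly like this sketch (hypothetical variant, not part of this commit):

def respond_messages(user_message, history):
    if not user_message:
        return history, history
    bot_reply = chat_with_memory(user_message)
    history = (history or []) + [
        {"role": "user", "content": user_message},       # messages format: dicts with role/content
        {"role": "assistant", "content": bot_reply},
    ]
    return history, history

# wired the same way, but with: chatbot = gr.Chatbot(type="messages")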
requirements.txt CHANGED
@@ -1,11 +1,12 @@
- gradio==5.41.1  # Matches Spaces SDK version
  transformers
  torch
  sentencepiece
- langchain  # Required for ConversationBufferMemory
  bitsandbytes
  peft
  xformers
  unsloth
  unsloth_zoo
  huggingface_hub
 
 
+ gradio==5.41.1
  transformers
  torch
  sentencepiece
+ langchain
  bitsandbytes
  peft
  xformers
  unsloth
  unsloth_zoo
  huggingface_hub
+ gdown
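
Note: before wiring GDRIVE_ZIP_ID into the Space's Secrets, the Drive zip layout can be sanity-checked with a small standalone script that mirrors the download logic in SLM_CService.py. The script name is hypothetical; it reuses the same PERSIST_DIR / GDRIVE_ZIP_ID environment variables and the gdown/zipfile calls shown above.

# check_drive_assets.py - hypothetical helper, not part of this commit
import os, zipfile
import gdown

persist = os.environ.get("PERSIST_DIR", "./slm_assets")  # local stand-in for /data/slm_assets
os.makedirs(persist, exist_ok=True)

zip_path = os.path.join(persist, "assets.zip")
gdown.download(id=os.environ["GDRIVE_ZIP_ID"], output=zip_path, quiet=False)
with zipfile.ZipFile(zip_path) as zf:
    zf.extractall(persist)

print(sorted(os.listdir(persist)))  # expect the adapter/ and tokenizer/ folders the module checks for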