# ── SLM_CService.py ───────────────────────────────────────────────────────────
# Model load + FSM + conversational memory for your Gradio Space.
import os
import re
# Keep OpenMP quiet in Spaces logs
os.environ["OMP_NUM_THREADS"] = "1"
# Ensure we don't accidentally run offline
os.environ.pop("HF_HUB_OFFLINE", None)
# 1) Unsloth must be imported BEFORE transformers/peft to apply optimizations.
# (Otherwise you may see perf/memory warnings.)
# Ref: Unsloth team warning in issues.
import unsloth  # noqa: E402  # must come before transformers/peft
import torch
from transformers import AutoTokenizer, BitsAndBytesConfig, pipeline
from peft import PeftModel
from langchain.memory import ConversationBufferMemory
# ──────────────────────────────────────────────────────────────────────────────
# Your Hub repo that contains the tokenizer + PEFT adapter files
REPO = "ThomasBasil/bitext-qlora-tinyllama"
BASE = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# If your files are nested, set this to the exact subfolder path (or use
# the HF_SUBFOLDER env var from Space → Settings → Variables).
# Example from your screenshot:
DEFAULT_SUBFOLDER = "bitext-qlora-tinyllama-20250807T224217Z-1-001/bitext-qlora-tinyllama"
SUBFOLDER = os.environ.get("HF_SUBFOLDER", DEFAULT_SUBFOLDER)
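# Optional sanity check; a minimal sketch assuming huggingface_hub is installed
# (it ships with transformers). HF_DEBUG_LIST_FILES is a hypothetical variable,
# not used anywhere else: set it to "1" in Space → Settings → Variables to print
# the repo's file list and confirm SUBFOLDER points at the tokenizer/adapter files.
if os.environ.get("HF_DEBUG_LIST_FILES") == "1":
    from huggingface_hub import list_repo_files
    print("\n".join(sorted(list_repo_files(REPO))))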
# 4-bit NF4 quantization config (QLoRA-style)
# Ref: Transformers bitsandbytes quantization docs.
bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
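# NF4 weights with double quantization keep the base model in 4-bit on the GPU,
# while bnb_4bit_compute_dtype=torch.bfloat16 runs the matmuls in bfloat16; this
# is the usual QLoRA-style setup for fitting a ~1.1B model on a small GPU.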
# ---- Robust helpers to load from root or subfolder ---------------------------
def _load_tokenizer(repo_id: str):
    """
    Try to load the tokenizer from the repo root; if missing, fall back to the
    configured subfolder. Transformers supports `subfolder` in from_pretrained
    for tokenizers.
    """
    # Try the repo root first
    try:
        tok = AutoTokenizer.from_pretrained(repo_id, use_fast=False)
    except Exception:
        # Try a "tokenizer" subdir at the root
        try:
            tok = AutoTokenizer.from_pretrained(repo_id, subfolder="tokenizer", use_fast=False)
        except Exception:
            # Try the provided nested path
            tok = AutoTokenizer.from_pretrained(repo_id, subfolder=SUBFOLDER, use_fast=False)
    # Sensible defaults for a causal LM
    if tok.pad_token_id is None and tok.eos_token_id is not None:
        tok.pad_token_id = tok.eos_token_id
    tok.padding_side = "left"
    tok.truncation_side = "right"
    return tok
def _attach_adapter(base_model, repo_id: str):
    """
    Attach the PEFT adapter from the repo root; if not found, try subfolder variants.
    (PEFT supports kwargs like `subfolder`, though older versions had quirks;
    if you ever hit issues, place the adapter files at the repo root.)
    """
    # Try the repo root
    try:
        return PeftModel.from_pretrained(base_model, repo_id)
    except Exception:
        # Try an 'adapter' subdir at the root
        try:
            return PeftModel.from_pretrained(base_model, repo_id, subfolder="adapter")
        except Exception:
            # Try the provided nested path
            return PeftModel.from_pretrained(base_model, repo_id, subfolder=SUBFOLDER)
# ---- Load tokenizer, base model (4-bit), and attach adapter ------------------
tokenizer = _load_tokenizer(REPO)
# FastLanguageModel.from_pretrained returns a (model, tokenizer) tuple; we keep the
# tokenizer loaded above and take only the model here.
model, _ = unsloth.FastLanguageModel.from_pretrained(
    BASE,
    load_in_4bit=True,
    quantization_config=bnb_cfg,  # prefer quantization_config over legacy args
    device_map="auto",
    trust_remote_code=True,
)
model = _attach_adapter(model, REPO)
model.eval()
# Generation parameters passed at construction time are forwarded to .generate()
# on every call (the text-generation pipeline collects unrecognized kwargs as
# generate kwargs). Ref: Transformers pipelines docs.
chat_pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    trust_remote_code=True,
    return_full_text=False,
    max_new_tokens=128,
    do_sample=True,
    top_p=0.9,
    temperature=0.7,
)
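# Illustrative call (documentation only, not used by app.py): with return_full_text=False,
#   chat_pipe("User: Where is my order?\nAssistant: ")[0]["generated_text"]
# returns just the generated continuation, without the prompt.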
# ──────────────────────────────────────────────────────────────────────────────
# Conversational Memory (LangChain)
# memory_key must match the "chat_history" key read in _history_to_prompt below.
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
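# Illustrative round-trip (documentation only):
#   memory.save_context({"input": "hi"}, {"output": "hello"})
#   memory.load_memory_variables({})
#   -> {"chat_history": [HumanMessage(content="hi"), AIMessage(content="hello")]}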
# ──────────────────────────────────────────────────────────────────────────────
# Simple FSM helpers
order_re = re.compile(r"#(\d{1,10})")
def extract_order(text: str):
    m = order_re.search(text)
    return m.group(1) if m else None
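# Example: extract_order("Where is my order #12345?") -> "12345"; returns None
# when no "#<digits>" pattern is present.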
def handle_status(o): return f"Order #{o} is in transit and should arrive in 3–5 business days."
def handle_eta(o): return f"Delivery for order #{o} typically takes 3–5 days; you can track it at https://track.example.com/{o}"
def handle_track(o): return f"Track order #{o} here: https://track.example.com/{o}"
def handle_link(o): return f"Here's the latest tracking link for order #{o}: https://track.example.com/{o}"
def handle_return_policy(_=None):
    return ("Our return policy allows returns of unused items in their original packaging within 30 days of receipt. "
            "Would you like me to connect you with a human agent?")
def handle_gratitude(_=None):
    return "You're welcome! Is there anything else I can help with?"
def handle_escalation(_=None):
    return "I'm sorry, I don't have that information. Would you like me to connect you with a human agent?"
stored_order = None
pending_intent = None
def _history_to_prompt(user_input: str) -> str:
    """Build a plain-text prompt that includes the chat history for fallback generation."""
    hist = memory.load_memory_variables({}).get("chat_history", [])
    prompt = "You are a helpful support assistant.\n"
    for msg in hist:
        # LangChain messages expose .type ('human'/'ai') and .content
        mtype = getattr(msg, "type", "")
        role = "User" if mtype == "human" else "Assistant"
        content = getattr(msg, "content", "")
        prompt += f"{role}: {content}\n"
    prompt += f"User: {user_input}\nAssistant: "
    return prompt
def chat_with_memory(user_input: str) -> str:
    """Main entrypoint called by app.py."""
    global stored_order, pending_intent
    ui = (user_input or "").strip()
    low = ui.lower()
    # A) quick intent short-circuits
    if any(tok in low for tok in ["thank you", "thanks", "thx"]):
        reply = handle_gratitude()
        memory.save_context({"input": ui}, {"output": reply})
        return reply
    if "return" in low:
        reply = handle_return_policy()
        memory.save_context({"input": ui}, {"output": reply})
        return reply
    # B) order number?
    new_o = extract_order(ui)
    if new_o:
        stored_order = new_o
        if pending_intent in ("status", "eta", "track", "link"):
            fn = {"status": handle_status, "eta": handle_eta, "track": handle_track, "link": handle_link}[pending_intent]
            reply = fn(stored_order)
            pending_intent = None
            memory.save_context({"input": ui}, {"output": reply})
            return reply
    # C) intent classification
    if any(k in low for k in ["status", "where is my order", "check status"]):
        intent = "status"
    elif any(k in low for k in ["how long", "eta", "delivery time"]):
        intent = "eta"
    elif any(k in low for k in ["how can i track", "track my order", "where is my package"]):
        intent = "track"
    elif "tracking link" in low or "resend" in low:
        intent = "link"
    else:
        intent = "fallback"
    # D) handle core intents (ask for the order number first if needed)
    if intent in ("status", "eta", "track", "link"):
        if not stored_order:
            pending_intent = intent
            reply = "Sure, what's your order number (e.g., #12345)?"
        else:
            fn = {"status": handle_status, "eta": handle_eta, "track": handle_track, "link": handle_link}[intent]
            reply = fn(stored_order)
        memory.save_context({"input": ui}, {"output": reply})
        return reply
    # E) fallback → generate with chat history context
    prompt = _history_to_prompt(ui)
    out = chat_pipe(prompt)[0]["generated_text"]
    reply = out.split("Assistant:")[-1].strip()
    memory.save_context({"input": ui}, {"output": reply})
    return reply
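# ──────────────────────────────────────────────────────────────────────────────
# Minimal local smoke test; a sketch only, not part of the Gradio app (app.py
# imports chat_with_memory directly). These turns exercise just the FSM path,
# so they run without invoking the model fallback.
if __name__ == "__main__":
    for turn in ["Where is my order?", "#12345", "thanks!"]:
        print(f"User: {turn}")
        print(f"Bot:  {chat_with_memory(turn)}")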