# ─── SLM_CService.py ─────────────────────────────────────────────────────────
# Launch-time model setup + FSM + conversational memory for the chatbot.

import os, shutil, zipfile
os.environ["OMP_NUM_THREADS"] = "1"          # quiet libgomp noise
os.environ.pop("HF_HUB_OFFLINE", None)       # avoid accidental offline mode

# 1) Unsloth must be imported before transformers
import unsloth
import torch

from transformers import AutoTokenizer, BitsAndBytesConfig, pipeline
from peft import PeftModel
from langchain.memory import ConversationBufferMemory
import gdown
import re

# ── Persistent storage (HF Spaces -> Settings -> Persistent storage) ─────────
# Docs: /data persists across Space restarts when persistent storage is enabled.
PERSIST_DIR   = os.environ.get("PERSIST_DIR", "/data/slm_assets")
ADAPTER_DIR   = os.path.join(PERSIST_DIR, "adapter")
TOKENIZER_DIR = os.path.join(PERSIST_DIR, "tokenizer")
ZIP_PATH      = os.path.join(PERSIST_DIR, "assets.zip")

# ── Provide Google Drive IDs as Secrets (HF Space -> Settings -> Variables) ──
# Either one zip with both folders...
GDRIVE_ZIP_ID       = os.environ.get("GDRIVE_ZIP_ID")
# ...or separate zips/files for each:
GDRIVE_ADAPTER_ID   = os.environ.get("GDRIVE_ADAPTER_ID")
GDRIVE_TOKENIZER_ID = os.environ.get("GDRIVE_TOKENIZER_ID")

def _ensure_dirs():
    os.makedirs(PERSIST_DIR, exist_ok=True)
    os.makedirs(ADAPTER_DIR, exist_ok=True)
    os.makedirs(TOKENIZER_DIR, exist_ok=True)

def _have_local_assets():
    # minimal sanity checks for typical PEFT/tokenizer files
    tok_ok = any(os.path.exists(os.path.join(TOKENIZER_DIR, f))
                 for f in ("tokenizer.json", "tokenizer.model", "tokenizer_config.json"))
    lora_ok = any(os.path.exists(os.path.join(ADAPTER_DIR, f))
                  for f in ("adapter_config.json", "adapter_model.bin", "adapter_model.safetensors"))
    return tok_ok and lora_ok

def _download_from_drive():
    """Download adapter/tokenizer from Google Drive into /data using gdown."""
    _ensure_dirs()
    if GDRIVE_ZIP_ID:
        gdown.download(id=GDRIVE_ZIP_ID, output=ZIP_PATH, quiet=False)  # gdown handles Google Drive downloads
        with zipfile.ZipFile(ZIP_PATH, "r") as zf:
            zf.extractall(PERSIST_DIR)
        return

    if GDRIVE_ADAPTER_ID:
        ad_zip = os.path.join(PERSIST_DIR, "adapter.zip")
        gdown.download(id=GDRIVE_ADAPTER_ID, output=ad_zip, quiet=False)
        try:
            with zipfile.ZipFile(ad_zip, "r") as zf:
                zf.extractall(ADAPTER_DIR)
        except zipfile.BadZipFile:
            # not a zip – assume single file
            shutil.move(ad_zip, os.path.join(ADAPTER_DIR, "adapter_model.bin"))

    if GDRIVE_TOKENIZER_ID:
        tk_zip = os.path.join(PERSIST_DIR, "tokenizer.zip")
        gdown.download(id=GDRIVE_TOKENIZER_ID, output=tk_zip, quiet=False)
        try:
            with zipfile.ZipFile(tk_zip, "r") as zf:
                zf.extractall(TOKENIZER_DIR)
        except zipfile.BadZipFile:
            shutil.move(tk_zip, os.path.join(TOKENIZER_DIR, "tokenizer.json"))

# ── Ensure local assets from Drive (first launch will download) ──────────────
if not _have_local_assets():
    _download_from_drive()     # persists in /data if persistent storage is enabled

# ── Tokenizer (from your Drive-backed folder) ────────────────────────────────
tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR, use_fast=False)
tokenizer.pad_token_id    = tokenizer.eos_token_id
tokenizer.padding_side    = "left"
tokenizer.truncation_side = "right"

# ── Base model (4-bit) via Unsloth + your PEFT adapter ──────────────────────
BASE = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# FastLanguageModel.from_pretrained returns a (model, tokenizer) tuple; keep the model only,
# since the fine-tuned tokenizer is already loaded from TOKENIZER_DIR above.
model, _ = unsloth.FastLanguageModel.from_pretrained(
    BASE,
    load_in_4bit=True,
    quantization_config=bnb_cfg,     # explicit 4-bit config (NF4 + double quantization)
    device_map="auto",
    trust_remote_code=True,
)

# Attach the fine-tuned LoRA adapter on top of the 4-bit base model.
model = PeftModel.from_pretrained(model, ADAPTER_DIR)

# ── Text-generation pipeline ─────────────────────────────────────────────────
# Generation parameters passed here are stored as forward params and forwarded
# to model.generate() on every call.
chat_pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    trust_remote_code=True,
    return_full_text=False,
    max_new_tokens=128,
    do_sample=True,
    top_p=0.9,
    temperature=0.7,
)
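
# Note (illustrative, not executed at import): the text-generation pipeline returns a list of
# dicts with a "generated_text" field, e.g. chat_pipe(some_prompt)[0]["generated_text"]; that
# shape is what the fallback branch in chat_with_memory() relies on below.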

# ── Conversational memory (LangChain) ────────────────────────────────────────
# ConversationBufferMemory stores the full turn-by-turn chat history.
# memory_key must match the "chat_history" key read in _history_to_prompt below.
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)

# ── FSM helpers (your original logic, kept intact) ───────────────────────────
order_re = re.compile(r"#(\d{1,10})")
def extract_order(text: str):
    m = order_re.search(text)
    return m.group(1) if m else None
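# Example (made-up input): extract_order("Where is order #12345?") returns "12345";
# text without a "#<digits>" pattern returns None.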

def handle_status(o): return f"Order #{o} is in transit and should arrive in 3–5 business days."
def handle_eta(o):    return f"Delivery for order #{o} typically takes 3–5 days; you can track it at https://track.example.com/{o}"
def handle_track(o):  return f"Track order #{o} here: https://track.example.com/{o}"
def handle_link(o):   return f"Here’s the latest tracking link for order #{o}: https://track.example.com/{o}"
def handle_return_policy(_=None):
    return ("Our return policy allows returns of unused items in their original packaging within 30 days of receipt. "
            "Would you like me to connect you with a human agent?")
def handle_gratitude(_=None):
    return "You’re welcome! Is there anything else I can help with?"
def handle_escalation(_=None):
    return "I’m sorry, I don’t have that information. Would you like me to connect you with a human agent?"

stored_order   = None
pending_intent = None

def _history_to_prompt(user_input: str) -> str:
    """Build a prompt from LangChain memory turns for fallback generation."""
    hist = memory.load_memory_variables({}).get("chat_history", [])
    prompt = "You are a helpful support assistant.\n"
    for msg in hist:
        # LangChain messages expose a .type like 'human'/'ai' in many versions
        mtype = getattr(msg, "type", "")
        role  = "User" if mtype == "human" else "Assistant"
        content = getattr(msg, "content", "")
        prompt += f"{role}: {content}\n"
    prompt += f"User: {user_input}\nAssistant: "
    return prompt

def chat_with_memory(user_input: str) -> str:
    global stored_order, pending_intent

    ui = user_input.strip()
    low = ui.lower()

    # A) quick intent short-circuits
    if any(tok in low for tok in ["thank you", "thanks", "thx"]):
        reply = handle_gratitude()
        memory.save_context({"input": ui}, {"output": reply})
        return reply
    if "return" in low:
        reply = handle_return_policy()
        memory.save_context({"input": ui}, {"output": reply})
        return reply

    # B) order number?
    new_o = extract_order(ui)
    if new_o:
        stored_order = new_o
        if pending_intent in ("status", "eta", "track", "link"):
            fn = {"status": handle_status, "eta": handle_eta, "track": handle_track, "link": handle_link}[pending_intent]
            reply = fn(stored_order)
            pending_intent = None
            memory.save_context({"input": ui}, {"output": reply})
            return reply

    # C) intent classification
    if any(k in low for k in ["status", "where is my order", "check status"]):
        intent = "status"
    elif any(k in low for k in ["how long", "eta", "delivery time"]):
        intent = "eta"
    elif any(k in low for k in ["how can i track", "track my order", "where is my package"]):
        intent = "track"
    elif "tracking link" in low or "resend" in low:
        intent = "link"
    else:
        intent = "fallback"

    # D) handle core intents (ask for order first if needed)
    if intent in ("status", "eta", "track", "link"):
        if not stored_order:
            pending_intent = intent
            reply = "Sureβ€”what’s your order number (e.g., #12345)?"
        else:
            fn = {"status": handle_status, "eta": handle_eta, "track": handle_track, "link": handle_link}[intent]
            reply = fn(stored_order)
        memory.save_context({"input": ui}, {"output": reply})
        return reply

    # E) fallback → generate with chat history context
    prompt = _history_to_prompt(ui)
    out = chat_pipe(prompt)[0]["generated_text"]
    # keep only the current reply; drop any continuation the model invents (e.g. a fake next "User:" turn)
    reply = out.split("Assistant:")[-1].split("\nUser:")[0].strip()
    memory.save_context({"input": ui}, {"output": reply})
    return reply
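
# ── Local smoke test (illustrative only; the sample inputs below are made up) ─
# Running this module directly walks the FSM through a typical exchange. In the
# Space itself, chat_with_memory() is called by the UI layer instead.
if __name__ == "__main__":
    for turn in (
        "Hi, where is my order?",            # status intent without an order number
        "#12345",                            # supplies the number; pending intent resolves
        "Can you resend the tracking link?", # link intent reuses the stored order
        "Thanks!",                           # gratitude short-circuit
    ):
        print(f"User: {turn}")
        print(f"Bot:  {chat_with_memory(turn)}")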