import os, shutil, zipfile

# Limit BLAS/OpenMP thread fan-out inside the container.
os.environ["OMP_NUM_THREADS"] = "1"
# Ensure the HF Hub is reachable in case the base model must be downloaded.
os.environ.pop("HF_HUB_OFFLINE", None)
import unsloth  # import unsloth before transformers so its patches apply
import torch

from transformers import AutoTokenizer, BitsAndBytesConfig, pipeline
from peft import PeftModel
from langchain.memory import ConversationBufferMemory
import gdown
import re

# Persistent locations for the fine-tuned assets (override via PERSIST_DIR).
PERSIST_DIR = os.environ.get("PERSIST_DIR", "/data/slm_assets")
ADAPTER_DIR = os.path.join(PERSIST_DIR, "adapter")
TOKENIZER_DIR = os.path.join(PERSIST_DIR, "tokenizer")
ZIP_PATH = os.path.join(PERSIST_DIR, "assets.zip")

# Google Drive file IDs: either one zip holding both assets,
# or separate adapter/tokenizer archives.
GDRIVE_ZIP_ID = os.environ.get("GDRIVE_ZIP_ID")
GDRIVE_ADAPTER_ID = os.environ.get("GDRIVE_ADAPTER_ID")
GDRIVE_TOKENIZER_ID = os.environ.get("GDRIVE_TOKENIZER_ID")

def _ensure_dirs():
    os.makedirs(PERSIST_DIR, exist_ok=True)
    os.makedirs(ADAPTER_DIR, exist_ok=True)
    os.makedirs(TOKENIZER_DIR, exist_ok=True)


def _have_local_assets():
    """Return True if both a tokenizer and a LoRA adapter are already on disk."""
    tok_ok = any(os.path.exists(os.path.join(TOKENIZER_DIR, f))
                 for f in ("tokenizer.json", "tokenizer.model", "tokenizer_config.json"))
    lora_ok = any(os.path.exists(os.path.join(ADAPTER_DIR, f))
                  for f in ("adapter_config.json", "adapter_model.bin", "adapter_model.safetensors"))
    return tok_ok and lora_ok

def _download_from_drive():
    """Download adapter/tokenizer from Google Drive into PERSIST_DIR using gdown."""
    _ensure_dirs()

    # Preferred path: a single zip that expands to adapter/ and tokenizer/.
    if GDRIVE_ZIP_ID:
        gdown.download(id=GDRIVE_ZIP_ID, output=ZIP_PATH, quiet=False)
        with zipfile.ZipFile(ZIP_PATH, "r") as zf:
            zf.extractall(PERSIST_DIR)
        return

    # Fallback: separately hosted adapter and tokenizer files.
    if GDRIVE_ADAPTER_ID:
        ad_zip = os.path.join(PERSIST_DIR, "adapter.zip")
        gdown.download(id=GDRIVE_ADAPTER_ID, output=ad_zip, quiet=False)
        try:
            with zipfile.ZipFile(ad_zip, "r") as zf:
                zf.extractall(ADAPTER_DIR)
        except zipfile.BadZipFile:
            # Not a zip: treat the download as the raw adapter weights file.
            shutil.move(ad_zip, os.path.join(ADAPTER_DIR, "adapter_model.bin"))

    if GDRIVE_TOKENIZER_ID:
        tk_zip = os.path.join(PERSIST_DIR, "tokenizer.zip")
        gdown.download(id=GDRIVE_TOKENIZER_ID, output=tk_zip, quiet=False)
        try:
            with zipfile.ZipFile(tk_zip, "r") as zf:
                zf.extractall(TOKENIZER_DIR)
        except zipfile.BadZipFile:
            # Not a zip: treat the download as a bare tokenizer.json.
            shutil.move(tk_zip, os.path.join(TOKENIZER_DIR, "tokenizer.json"))

# Fetch assets on first boot; later restarts reuse the persisted copies.
if not _have_local_assets():
    _download_from_drive()

tokenizer = AutoTokenizer.from_pretrained(TOKENIZER_DIR, use_fast=False)
tokenizer.pad_token_id = tokenizer.eos_token_id  # reuse EOS as the pad token
tokenizer.padding_side = "left"        # left-pad for decoder-only generation
tokenizer.truncation_side = "right"

BASE = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Unsloth builds its own NF4 double-quant config when load_in_4bit=True;
# bnb_cfg spells out the equivalent settings for a plain-transformers load.
bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# FastLanguageModel.from_pretrained returns a (model, tokenizer) tuple;
# the bundled tokenizer is discarded in favor of the one loaded above.
model, _ = unsloth.FastLanguageModel.from_pretrained(
    BASE,
    load_in_4bit=True,
    device_map="auto",
    trust_remote_code=True,
)

# Attach the fine-tuned LoRA adapter on top of the 4-bit base model.
model = PeftModel.from_pretrained(model, ADAPTER_DIR)
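
# Optionally enable Unsloth's fused inference path here (a sketch; assumes the
# installed unsloth version exposes FastLanguageModel.for_inference):
# unsloth.FastLanguageModel.for_inference(model)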

# text-generation pipeline; generation parameters are passed directly so the
# pipeline forwards them to model.generate() on every call.
chat_pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    trust_remote_code=True,
    return_full_text=False,
    max_new_tokens=128,
    do_sample=True,
    top_p=0.9,
    temperature=0.7,
)

# Conversation memory; memory_key must match the key read in _history_to_prompt
# (the default key is "history", which would silently return an empty list).
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)

# Order numbers look like "#12345".
order_re = re.compile(r"#(\d{1,10})")

def extract_order(text: str):
    m = order_re.search(text)
    return m.group(1) if m else None
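
# For example: extract_order("Where is order #12345?") -> "12345";
# extract_order("hello") -> None.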

def handle_status(o): return f"Order #{o} is in transit and should arrive in 3–5 business days."
def handle_eta(o): return f"Delivery for order #{o} typically takes 3–5 days; you can track it at https://track.example.com/{o}"
def handle_track(o): return f"Track order #{o} here: https://track.example.com/{o}"
def handle_link(o): return f"Here's the latest tracking link for order #{o}: https://track.example.com/{o}"

def handle_return_policy(_=None):
    return ("Our return policy allows returns of unused items in their original packaging within 30 days of receipt. "
            "Would you like me to connect you with a human agent?")

def handle_gratitude(_=None):
    return "You're welcome! Is there anything else I can help with?"

def handle_escalation(_=None):
    return "I'm sorry, I don't have that information. Would you like me to connect you with a human agent?"

# Dialogue state: the last order number seen and any intent awaiting one.
stored_order = None
pending_intent = None


def _history_to_prompt(user_input: str) -> str:
    """Build a prompt from LangChain memory turns for fallback generation."""
    hist = memory.load_memory_variables({}).get("chat_history", [])
    prompt = "You are a helpful support assistant.\n"
    for msg in hist:
        # LangChain message objects expose .type ("human"/"ai") and .content.
        mtype = getattr(msg, "type", "")
        role = "User" if mtype == "human" else "Assistant"
        content = getattr(msg, "content", "")
        prompt += f"{role}: {content}\n"
    prompt += f"User: {user_input}\nAssistant: "
    return prompt

def chat_with_memory(user_input: str) -> str:
    global stored_order, pending_intent

    ui = user_input.strip()
    low = ui.lower()

    # Quick keyword intents that need no order number.
    if any(tok in low for tok in ["thank you", "thanks", "thx"]):
        reply = handle_gratitude()
        memory.save_context({"input": ui}, {"output": reply})
        return reply
    if "return" in low:
        reply = handle_return_policy()
        memory.save_context({"input": ui}, {"output": reply})
        return reply

    # An order number in this turn updates the state and may resolve a
    # previously asked-for intent.
    new_o = extract_order(ui)
    if new_o:
        stored_order = new_o
        if pending_intent in ("status", "eta", "track", "link"):
            fn = {"status": handle_status, "eta": handle_eta, "track": handle_track, "link": handle_link}[pending_intent]
            reply = fn(stored_order)
            pending_intent = None
            memory.save_context({"input": ui}, {"output": reply})
            return reply

    # Keyword-based intent routing.
    if any(k in low for k in ["status", "where is my order", "check status"]):
        intent = "status"
    elif any(k in low for k in ["how long", "eta", "delivery time"]):
        intent = "eta"
    elif any(k in low for k in ["how can i track", "track my order", "where is my package"]):
        intent = "track"
    elif "tracking link" in low or "resend" in low:
        intent = "link"
    else:
        intent = "fallback"

    # Order-bound intents: ask for the order number if we do not have one yet.
    if intent in ("status", "eta", "track", "link"):
        if not stored_order:
            pending_intent = intent
            reply = "Sure, what's your order number (e.g., #12345)?"
        else:
            fn = {"status": handle_status, "eta": handle_eta, "track": handle_track, "link": handle_link}[intent]
            reply = fn(stored_order)
        memory.save_context({"input": ui}, {"output": reply})
        return reply

    # Everything else falls back to the fine-tuned model.
    prompt = _history_to_prompt(ui)
    out = chat_pipe(prompt)[0]["generated_text"]
    reply = out.split("Assistant:")[-1].strip()
    memory.save_context({"input": ui}, {"output": reply})
    return reply
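
# A quick smoke test of the routing logic (hypothetical usage; the exact reply
# from the fallback branch depends on the fine-tuned model):
if __name__ == "__main__":
    print(chat_with_memory("Where is my order?"))  # asks for an order number
    print(chat_with_memory("It's #12345"))         # resolves the pending "status" intent
    print(chat_with_memory("thanks"))              # gratitude branch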