|
import re

import torch, unsloth, triton
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from langchain.memory import ConversationBufferMemory

FINETUNED_DIR = "/content/drive/MyDrive/bitext-qlora-tinyllama"
|
# 4-bit NF4 quantization with nested (double) quantization; matmuls run in bfloat16
bnb_cfg = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)

tokenizer = AutoTokenizer.from_pretrained(FINETUNED_DIR, use_fast=False)
tokenizer.pad_token_id = tokenizer.eos_token_id  # model has no dedicated pad token
tokenizer.padding_side = "left"        # left-pad so generation continues right after the prompt
tokenizer.truncation_side = "right"

model = AutoModelForCausalLM.from_pretrained(
    FINETUNED_DIR,
    quantization_config=bnb_cfg,
    device_map="auto",
    trust_remote_code=True,
)
|
|
|
|
|
# Buffer memory keeps the running transcript; prefixes match the chat roles
memory = ConversationBufferMemory(
    memory_key="user_lines",
    human_prefix="User",
    ai_prefix="Assistant",
    return_messages=False,
)

# Conversation state: the last order number seen, and any intent waiting on one
stored_order = None
pending_intent = None
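# Illustrative read-back (not required by the bot's flow): the buffer exposes
# the transcript as one formatted string under the configured key, e.g.
#   memory.load_memory_variables({})["user_lines"]
# returns something like "User: hi\nAssistant: hello".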
|
|
|
# Generation pipeline over the fine-tuned model (available for free-form replies)
chat_pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    trust_remote_code=True,
    return_full_text=False,
)
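# Optional smoke test (illustrative prompt and parameters); uncomment to verify
# the 4-bit model generates end to end:
# print(chat_pipe("How do I check my order status?", max_new_tokens=32)[0]["generated_text"])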
|
|
|
|
|
order_re = re.compile(r"#(\d{1,10})")  # order numbers written like "#12345"

def extract_order(text):
    m = order_re.search(text)
    return m.group(1) if m else None
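# e.g. extract_order("my order is #48213") -> "48213"; extract_order("hi") -> None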
|
|
|
def handle_status(o):
    return f"Order #{o} is in transit and should arrive in 3-5 business days."

def handle_eta(o):
    return (f"Delivery for order #{o} typically takes 3-5 days; "
            f"you can track it at https://track.example.com/{o}")

def handle_track(o):
    return f"Track order #{o} here: https://track.example.com/{o}"

def handle_link(o):
    return f"Here's the latest tracking link for order #{o}: https://track.example.com/{o}"

def handle_return_policy(_=None):
    return ("Our return policy allows returns of unused items in their original packaging "
            "within 30 days of receipt. Would you like me to connect you with a human agent?")

def handle_gratitude(_=None):
    return "You're welcome! Is there anything else I can help with?"

def handle_escalation(_=None):
    return "I'm sorry, I don't have that information. Would you like me to connect you with a human agent?"
|
|
|
|
|
INTENT_HANDLERS = {
    "status": handle_status,
    "eta": handle_eta,
    "track": handle_track,
    "link": handle_link,
}

def chat_with_memory(user_input: str) -> str:
    global stored_order, pending_intent

    # Pick up an order number from this turn, if present
    new_o = extract_order(user_input)
    if new_o:
        stored_order = new_o
        # A previous turn asked for the order number; answer that intent now
        if pending_intent in INTENT_HANDLERS:
            reply = INTENT_HANDLERS[pending_intent](stored_order)
            pending_intent = None
            memory.save_context({"input": user_input}, {"output": reply})
            return reply

    # Keyword-based intent routing
    intent = None
    ui = user_input.lower().strip()
    if any(tok in ui for tok in ["thank you", "thanks", "thx"]):
        reply = handle_gratitude()
    elif "return" in ui:
        reply = handle_return_policy()
    elif any(k in ui for k in ["status", "where is my order", "check status"]):
        intent = "status"
    elif any(k in ui for k in ["how long", "eta", "delivery time"]):
        intent = "eta"
    elif any(k in ui for k in ["how can i track", "track my order", "where is my package"]):
        intent = "track"
    elif "tracking link" in ui or "resend" in ui:
        intent = "link"
    else:
        intent = "fallback"

    if intent in INTENT_HANDLERS:
        if not stored_order:
            # Remember what was asked and request the missing order number
            pending_intent = intent
            reply = "Sure! What's your order number (e.g., #12345)?"
        else:
            reply = INTENT_HANDLERS[intent](stored_order)
    elif intent == "fallback":
        reply = handle_escalation()

    # Log the exchange once per turn
    memory.save_context({"input": user_input}, {"output": reply})
    return reply
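# Hypothetical three-turn demo of the memory-aware flow: the bot first asks for
# the missing order number, resolves the pending "status" intent once the number
# appears, then handles the thank-you. Order number and wording are illustrative.
if __name__ == "__main__":
    for turn in ["Where is my order?", "It's #48213", "Thanks!"]:
        print(f"User: {turn}")
        print(f"Assistant: {chat_with_memory(turn)}")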
|
|