"""Westlake Financial virtual support agent ("Taylor").

Builds a LangGraph tool-calling agent (OpenAI chat model + tools + a Chroma
retriever over ``./westlake-policy.txt``), checkpoints conversations in an
in-memory SQLite saver, and serves it through a Gradio chat UI.

Credentials are read from the environment (``OPENAI_API_KEY``,
``LANGCHAIN_API_KEY``); none are stored in this file.
"""

import os
import uuid
from datetime import datetime, timedelta
from json import tool  # NOTE(review): shadowed by the langchain `tool` imports below.
from typing import Annotated, Literal, Optional, TypedDict

import gradio as gr
import pytz
from langchain.pydantic_v1 import BaseModel, Field
from langchain.tools import BaseTool, StructuredTool, tool
from langchain.tools.retriever import create_retriever_tool
from langchain_anthropic import ChatAnthropic
from langchain_community.document_loaders import TextLoader
from langchain_community.vectorstores import Chroma
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_core.tools import tool  # last import wins: this is the `tool` used below
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.graph.message import add_messages
from langgraph.prebuilt import ToolNode
from langsmith import traceable
from num2words import num2words
from typing_extensions import TypedDict

# Secrets come from the environment. Keys that were previously hard-coded in
# this file should be treated as leaked and rotated.
OPENAI_API_KEY = os.environ.get("OPENAI_API_KEY")

_PACIFIC_TZ = pytz.timezone("US/Pacific")


def get_today_plus_n_days_date(num_days):
    """Return the US/Pacific date ``num_days`` from now, spelled out.

    The day is rendered as an ordinal word, e.g. ``"the fifth of July, 2024"``.
    """
    date_object = datetime.now().astimezone(_PACIFIC_TZ) + timedelta(days=num_days)
    # Pass the integer day (not the zero-padded "%d" string) to num2words.
    ordinal_day = num2words(date_object.day, ordinal=True)
    return date_object.strftime("the {} of %B, %Y").format(ordinal_day)


def get_current_time():
    """Return the current US/Pacific time formatted as ``"%I:%M%p %Z"``."""
    return datetime.now().astimezone(_PACIFIC_TZ).strftime("%I:%M%p %Z")


def get_today_date():
    """Return today's US/Pacific date in the spelled-out form used in disclaimers."""
    return get_today_plus_n_days_date(0)


# Conversation id used as the LangGraph checkpoint thread; reset by clear().
# NOTE(review): this is a single module-level value shared by every Gradio
# session, so concurrent users share one conversation thread.
thread_id = str(uuid.uuid4())

# Build the policy retriever: load the policy text, chunk it, embed it into Chroma.
westlake_policy = TextLoader("./westlake-policy.txt").load()
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
splits = text_splitter.split_documents(westlake_policy)
vectorstore = Chroma.from_documents(
    documents=splits, embedding=OpenAIEmbeddings(api_key=OPENAI_API_KEY)
)
retriever = vectorstore.as_retriever()
retriever_tool = create_retriever_tool(
    retriever,
    "retrieve_westlake_policy",
    "Search and return information about Westlake company policy",
)


def clear():
    """Start a fresh conversation by replacing the module-level ``thread_id``."""
    global thread_id
    print("Clear button clicked")
    thread_id = str(uuid.uuid4())


memory = SqliteSaver.from_conn_string(":memory:")

# Enable LangSmith tracing; LANGCHAIN_API_KEY must be supplied by the environment.
os.environ["LANGCHAIN_TRACING_V2"] = "true"

## DEFINE TOOLS ##
# NOTE: the docstrings of the @tool functions below are sent to the model as
# tool descriptions, so they are prompt text, not developer documentation.


# Arguments for escalate_to_human.
class EscalationInput(BaseModel):
    reason_escalate_to_human: Optional[str] = Field(
        description="string representing the reason the call needs to be escalated to a human agent, e.g. 'conversation not progressing'"
    )


# Returns an instruction for the model: escalate if a reason was given,
# otherwise ask the customer for one.
@tool(args_schema=EscalationInput)
@traceable(run_type="tool", name="Escalate To Human")
def escalate_to_human(**args):
    """If you cannot assist the customer based on the information you have, or if the customer is unhappy, or if the customer requests to be transferred"""
    args = EscalationInput(**args)
    print(f"Tool escalate_to_human called with params: {args}")
    if args.reason_escalate_to_human:
        return "Notify the customer that you are escalating to a live agent."
    return "Ask the customer if they can please provide a precise reason why they would like to speak with an agent."


# Arguments for make_payment. Only the Debit Card on file is supported;
# collecting a new card or bank account over the phone is not implemented.
class PaymentInput(BaseModel):
    desired_payment_amount: Optional[float] = Field(
        description="the amount that the customer would like to pay."
    )
    desired_payment_date: Optional[str] = Field(
        description="the date the customer would like to pay. e.g. next Tuesday, or tomorrow, or 2 weeks from now"
    )
    use_payment_method_on_file: Optional[bool] = Field(
        description="True if the customer wants to use the Debit Card ending in 9123 on file, False if the customer does not want to use the Debit Card ending in 9123 on file."
    )
    authorize_final_disclaimer: Optional[bool] = Field(
        description="True if the customer agrees to the final payment disclaimer, False if the customer does not agree to the final disclaimer."
    )


# Slot-filling state machine: each missing field yields an instruction asking
# the model to collect it. A reply prefixed with "DETERMINISTIC" is routed by
# should_continue_tools straight back to the user (marker stripped) without
# another model call.
@tool(args_schema=PaymentInput)
@traceable(run_type="tool", name="Make a Payment")
def make_payment(**args):
    """Call this tool if the customer would like to make a payment over the phone with you, either with a payment method on file or a new payment method. You should call this tool if the user is indicating they want to provide a new payment method not on file. This tool is not to be called for cancelling a payment, issuing a refund, or changing the date of a scheduled payment. You should pass ALL known parameters from previous conversation history to the tool."""
    args = PaymentInput(**args)
    print("Called make_payment tool")
    if not args.desired_payment_amount:
        return "Ask the customer to confirm the payment amount."
    if not args.desired_payment_date:
        return "Ask the customer to confirm the payment date."
    if args.use_payment_method_on_file is None:
        return "Ask the customer to confirm if they would like to use the Debit Card ending in 9123 on file."
    if not args.use_payment_method_on_file:
        # New payment methods cannot be collected by this agent. (Previously this
        # branch read fields that no longer exist on PaymentInput and crashed.)
        return "You cannot collect a new payment method over the phone. Tell the customer that you are escalating to a live agent."

    payment_method = "Debit Card ending in 9123 on file"
    if args.authorize_final_disclaimer is None:
        return f"DETERMINISTIC Great! I'll read a quick disclaimer. Brad Thompson, today, {get_today_date()} you are authorizing a payment in the amount of ${args.desired_payment_amount}, plus a $5 processing fee, dated on {args.desired_payment_date} using your {payment_method}. By authorizing this payment, you agree that you are the account holder or authorized user. Please say yes to proceed."
    if not args.authorize_final_disclaimer:
        return f"Tell the customer you have noted on their account that they plan to pay ${args.desired_payment_amount} on {args.desired_payment_date}."
    return f"Tell the customer you have processed their payment of ${args.desired_payment_amount} on {args.desired_payment_date} using the {payment_method}."


# Arguments for notate_promise_to_pay.
class PromiseInput(BaseModel):
    desired_payment_amount: Optional[float] = Field(
        description="the amount that the customer intends to pay."
    )
    desired_payment_date: Optional[str] = Field(
        description="the date the customer intends to pay."
    )
    # Free-text method (was wrongly typed bool, which rejected values like "cash").
    desired_payment_method: Optional[str] = Field(
        description="how the customer plans to make a payment in the future e.g. in cash, pay near me, debit card, phoning it in, ACH, using the app"
    )


# Collects amount, date and method, then instructs the model to confirm the note.
@tool(args_schema=PromiseInput)
@traceable(run_type="tool", name="Notate a Promise to Pay")
def notate_promise_to_pay(**args):
    """If the customer would like to notify you of a payment they plan to make in the future, but does not require you to process the payment immediately over the phone"""
    args = PromiseInput(**args)
    print("Called notate_promise_to_pay tool")
    if not args.desired_payment_amount:
        return "Ask the customer how much they plan on paying."
    if not args.desired_payment_date:
        return "Ask the customer when they plan to make a payment."
    if not args.desired_payment_method:
        return "Ask the customer how they plan to make the payment."
    return f"Tell the customer you have noted on their account that they plan to pay ${args.desired_payment_amount} on {args.desired_payment_date} using {args.desired_payment_method}."


# Arguments for fetch_company_policy.
class CompanyPolicy(BaseModel):
    company_policy_topic: str = Field(
        description="the specific company policy the customer is asking about e.g. late fees, insurance, GAP, titles, grace period, payment methods accepted"
    )


# Returns a fixed policy summary regardless of topic. Not currently registered
# in `tools` (the retriever tool is used instead).
@tool(args_schema=CompanyPolicy)
@traceable(run_type="tool", name="Fetch Company Policy")
def fetch_company_policy(**args):
    """If the customer is asking about or talking about Westlake's company policies"""
    args = CompanyPolicy(**args)
    print("Called fetch company policy tool")
    return """Westlake Company Policy:
Westlake Financial does not have a grace period. Interest continues to accrue once the due date has passed.
You may inform the customer that late fees are charged 10 days after their due date. You do not know the amount that will be charged as it is dependent on the customer's contract. Do not ask the customer for any account information or contract information or their due date.
In general, Westlake accepts the following payment methods: debit card, bank account, MyAccount mobile app, Moneygram by using code 2603, check via mail, or cash by visiting a Pay Near Me location. Westlake DOES NOT accept credit card payments.
Payments will post one business day after they have been received. If there is a delay in processing the payment on Westlake's end, the payment will be credited as of the date it was received.
The fee for processing a payment over the phone is $5.
"""


# Arguments for fetch_customer_information.
class CustomerInfo(BaseModel):
    customer_info_topic: Optional[str] = Field(
        description="the specific account information the customer is asking about e.g. balance, pay off amount, late fees, vehicle on file, next due date"
    )


# Returns a fixed (demo) account record regardless of topic.
@tool(args_schema=CustomerInfo)
@traceable(run_type="tool", name="Fetch Customer Information")
def fetch_customer_information(**args):
    """If the customer is asking about or talking about information regarding their account, like balance, vehicle on file,"""
    args = CustomerInfo(**args)
    print(f"Called fetch customer information tool with topic: {args.customer_info_topic}")
    return """Customer Account Information:
SSN: 1234
DOB: 1987-07-21
Customer Name: Brad Thompson
Account Number: 1234567
Account Balance: $11000
Pay off Amount: $13500
Remaining Length of Loan: 32 months
Total Delinquent Due Amount: $250
Total Delinquent Due Amount Without Late Charges: $230
Regular Monthly Payment: $230
Vehicle: 2018 Dodge Charger
Last Received Payment Date: 2024-06-24
Last Received Payment Amount: $230
Delinquent Days: 3
Payment Method on File: Debit Card ending in 9123
Latest Allowed Payment Date: 2024-09-01"""


tools = [
    escalate_to_human,
    make_payment,
    notate_promise_to_pay,
    fetch_customer_information,
    retriever_tool,
]

tool_node = ToolNode(tools)

## END DEFINE TOOLS ##


# Graph state: a message list merged with LangGraph's add_messages reducer.
class State(TypedDict):
    messages: Annotated[list, add_messages]


@traceable(name="Should Continue")
def should_continue(state: MessagesState) -> Literal["tools", END]:
    """Route after the chatbot node: go to "tools" if the model requested a tool call, else END."""
    messages = state["messages"]
    last_message = messages[-1]
    print(f"inside should_continue, last_message = {last_message}")
    if len(messages) >= 2:
        print(f"second_last_message = {messages[-2]}")
    if not last_message.tool_calls:
        return END
    print("Tool call detected")
    return "tools"


@traceable(name="Should Continue Tools")
def should_continue_tools(state: MessagesState) -> Literal["tools", END, "chatbot"]:
    """Route after the tools node.

    A tool result containing the "DETERMINISTIC" marker ends the turn so that
    text is shown to the user verbatim; the marker is stripped from the message
    in place. Otherwise control returns to the chatbot.
    """
    messages = state["messages"]
    last_message = messages[-1]
    print(f"inside should_continue_tools, last_message = {last_message}")
    if "DETERMINISTIC" in last_message.content:
        # Mutates the message object held in state so the user never sees the marker.
        last_message.content = last_message.content.replace("DETERMINISTIC", "").strip()
        return END
    return "chatbot"


primary_assistant_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are Taylor, a virtual, helpful, customer support assistant for Westlake Financial. "
            "Use the provided tools to search for customer information, company policies, and other information to assist the user's queries. You can only assist the customer with the tools provided. If the customer is asking for assistance on a task that cannot be accommodated by the tools at your disposal, offer to transfer them. "
            "When searching, be persistent. Expand your query bounds if the first search returns no results. "
            "If a search comes up empty, expand your search before giving up. "
            "Remember, you are not a real person, you are an artificial intelligence designed to help Westlake's customers.",
        ),
        MessagesPlaceholder(variable_name="messages"),
    ]
)

llm = primary_assistant_prompt | ChatOpenAI(
    api_key=OPENAI_API_KEY,
    model="gpt-4",
    max_tokens=500,
).bind_tools(tools)


@traceable(name="Chatbot")
def chatbot(state: State):
    """Run the prompt + tool-bound model on the conversation and append its reply."""
    return {"messages": [llm.invoke(state["messages"])]}


graph_builder = StateGraph(State)
graph_builder.add_node("chatbot", chatbot)
graph_builder.add_node("tools", tool_node)
graph_builder.set_entry_point("chatbot")
graph_builder.add_conditional_edges("chatbot", should_continue)
# No unconditional "tools" -> "chatbot" edge: should_continue_tools decides
# whether the next step is the chatbot or the end of the turn.
graph_builder.add_conditional_edges("tools", should_continue_tools)
graph = graph_builder.compile(checkpointer=memory)  # interrupt_after=["tools"]


async def respond(message, history):
    """Gradio chat handler: run one user turn through the graph and return the final text.

    A new conversation thread is started when Gradio's history is empty.
    NOTE(review): graph.stream is synchronous, so this coroutine blocks the
    event loop for the duration of the model/tool calls.
    """
    if len(history) == 0:
        clear()
    message = message.strip()
    print(f"Inside Gradio respond\nmessage = {message}\n")
    config = {"configurable": {"thread_id": thread_id}}
    last_message = None
    for event in graph.stream(
        {"messages": ("user", message)}, config, stream_mode="values"
    ):
        event["messages"][-1].pretty_print()
        last_message = event["messages"][-1]
    return last_message.content if last_message is not None else ""


with gr.Blocks() as demo:
    clear()
    clear_btn = gr.Button("Clear", render=False)
    clear_btn.click(fn=clear, api_name="clear")
    chat = gr.ChatInterface(
        respond,
        clear_btn=clear_btn,
        retry_btn=None,
        undo_btn=None,
        chatbot=gr.Chatbot(
            render=False,
            height=500,
        ),
    )

demo.queue().launch(share=False)