suwesh committed
Commit 6c8fb3a · verified · 1 parent: 3c4359a

Update app.py

Files changed (1):
  app.py  +147 -2
app.py CHANGED
@@ -1,6 +1,7 @@
 import gradio as gr
 from huggingface_hub import InferenceClient
 from transformers import AutoTokenizer, AutoModelForCausalLM
+import torch
 #import sqlite3
 import json
 from db_setup import setup_database
@@ -16,7 +17,152 @@ modelpath = "Salesforce/xLAM-1b-fc-r"
 model = AutoModelForCausalLM.from_pretrained(modelpath, torch_dtype="auto", trust_remote_code=True)
 tokenizer = AutoTokenizer.from_pretrained(modelpath)
 
+#=============prompt template and task instructions==============
+# Please use our provided instruction prompt for best performance
+task_instruction = """
+You are an expert in composing functions. You are given a question and a set of possible functions.
+Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
+If none of the functions can be used, point it out and refuse to answer.
+If the given question lacks the parameters required by the function, also point it out.
+""".strip()
+format_instruction = """
+The output MUST strictly adhere to the following JSON format, and NO other text MUST be included.
+The example format is as follows. Please make sure the parameter type is correct. If no function call is needed, please make tool_calls an empty list '[]'.
+```
+{
+    "tool_calls": [
+        {"name": "func_name1", "arguments": {"argument1": "value1", "argument2": "value2"}},
+        ... (more tool calls as required)
+    ]
+}
+```
+""".strip()  ##==output format
+#=============APIs and Functions Metadata========================
+get_weather_api = {
+    "name": "get_weather",
+    "description": "Get the current weather for a location",
+    "parameters": {
+        "type": "object",
+        "properties": {
+            "location": {
+                "type": "string",
+                "description": "The city and state, e.g. San Francisco, New York"
+            },
+            "unit": {
+                "type": "string",
+                "enum": ["celsius", "fahrenheit"],
+                "description": "The unit of temperature to return"
+            }
+        },
+        "required": ["location"]
+    }
+}
+
+search_api = {
+    "name": "search",
+    "description": "Search for information on the internet",
+    "parameters": {
+        "type": "object",
+        "properties": {
+            "query": {
+                "type": "string",
+                "description": "The search query, e.g. 'latest news on AI'"
+            }
+        },
+        "required": ["query"]
+    }
+}
+
+search_loanapplication = {
+    "name": "searchLA",
+    "description": "Search for Loan Application status",
+    "parameters": {
+        "type": "object",
+        "properties": {
+            "loan_application_id": {
+                "type": "string",
+                "description": "The unique alphanumeric identifier for a loan application, e.g. LA1234"
+            },
+            "phone_number": {
+                "type": "string",
+                "description": "The phone number associated with the loan application"
+            }
+        },
+        "required": ["loan_application_id", "phone_number"]
+    }
+}
+
+openai_format_tools = [search_api, search_loanapplication, get_weather_api]
+# Helper function to convert openai format tools to our more concise xLAM format
+def convert_to_xlam_tool(tools):
+    '''Convert an OpenAI-format tool dict (or a list of them) to the concise xLAM format.'''
+    if isinstance(tools, dict):
+        return {
+            "name": tools["name"],
+            "description": tools["description"],
+            "parameters": {k: v for k, v in tools["parameters"].get("properties", {}).items()}
+        }
+    elif isinstance(tools, list):
+        return [convert_to_xlam_tool(tool) for tool in tools]
+    else:
+        return tools
+#=========prompt builder====================================
+# Helper function to build the input prompt for our model
+def build_prompt(task_instruction: str, format_instruction: str, xlam_format_tools: list, query: str):
+    prompt = f"[BEGIN OF TASK INSTRUCTION]\n{task_instruction}\n[END OF TASK INSTRUCTION]\n\n"
+    prompt += f"[BEGIN OF AVAILABLE TOOLS]\n{json.dumps(xlam_format_tools)}\n[END OF AVAILABLE TOOLS]\n\n"
+    prompt += f"[BEGIN OF FORMAT INSTRUCTION]\n{format_instruction}\n[END OF FORMAT INSTRUCTION]\n\n"
+    prompt += f"[BEGIN OF QUERY]\n{query}\n[END OF QUERY]\n\n"
+    return prompt
+
+def to_model(query):
+    xlam_format_tools = convert_to_xlam_tool(openai_format_tools)
+    content = build_prompt(task_instruction, format_instruction, xlam_format_tools, query)
+    #print(f"content: {content}")
+    messages = [
+        {'role': 'user', 'content': content}
+    ]
+    inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
+    # tokenizer.eos_token_id is the id of the <|EOT|> token
+    outputs = model.generate(inputs, max_new_tokens=512, do_sample=False, num_return_sequences=1, eos_token_id=tokenizer.eos_token_id)
+    # Decode only the newly generated tokens, skipping the prompt
+    return tokenizer.decode(outputs[0][len(inputs[0]):], skip_special_tokens=True)
+
+def to_app(callobj):
+    # Parse the model's JSON output into parallel lists of function names and argument values
+    callobject = json.loads(callobj)
+    callfunctions = []
+    callarguments = []
+    for tool_call in callobject['tool_calls']:
+        callfunctions.append(tool_call['name'])
+        callarguments.append(list(tool_call['arguments'].values()))
+    #print(f"functions: {callfunctions}")
+    #print(f"arguments: {callarguments}")
+    return callfunctions, callarguments
+#===========sample application===================================
+def application(callfunctions, callarguments):
+    ## LOS application functions
+    def get_weather(location):
+        return f"weather function executed with city {location}"
+    def searchLA(laid, phnumber):
+        # Parameterized lookup; phnumber matches the tool schema but is not used in the query
+        cursor.execute("SELECT * FROM sfdc_la WHERE LAid = ?", (laid,))
+        result = cursor.fetchall()
+        return str(result)
+    losfunctions_list = ['get_weather', 'searchLA']
+    for i, functionname in enumerate(callfunctions):
+        if functionname in losfunctions_list:
+            function = globals().get(functionname) or locals().get(functionname)
+            if function:
+                arguments = callarguments[i]
+                out = function(*arguments)
+                return out
+#out = application(callfunctions, callarguments)
+def process_input(input_str):
+    if not input_str:
+        return "No input provided!"
+    model_out = to_model(input_str)
+    funs, args = to_app(model_out)
+    output_obj = application(funs, args)
+    return output_obj
 
 def respond(
     message,
@@ -36,7 +182,7 @@ def respond(
 
     messages.append({"role": "user", "content": message})
 
-    response = ""
+    response = process_input(message)
 
     for message in client.chat_completion(
         messages,
@@ -57,7 +203,6 @@ For information on how to customize the ChatInterface, peruse the gradio docs: h
 demo = gr.ChatInterface(
     respond,
     additional_inputs=[
-        gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
         gr.Slider(
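
A minimal sketch of the round trip the new helpers implement, with the model call bypassed; the query and JSON below are illustrative, not captured xLAM output:

    # Hand-written model response in the shape format_instruction demands.
    sample_model_out = '{"tool_calls": [{"name": "get_weather", "arguments": {"location": "San Francisco"}}]}'

    funs, args = to_app(sample_model_out)   # (['get_weather'], [['San Francisco']])
    result = application(funs, args)        # dispatches to the nested get_weather

Note that to_app forwards tool_call['arguments'].values() positionally, so dispatch relies on the model emitting arguments in the same order as the target function's signature; an optional argument such as unit would raise a TypeError in get_weather(location).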
 
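One possible hardening, not part of this commit: to_app calls json.loads directly on the raw generation, so a malformed generation would raise json.JSONDecodeError inside respond. A hypothetical safe_process_input wrapper could catch that before dispatch:

    def safe_process_input(input_str):
        # Hypothetical wrapper: same flow as process_input, but degrades
        # gracefully when the model ignores the format instruction.
        if not input_str:
            return "No input provided!"
        model_out = to_model(input_str)
        try:
            funs, args = to_app(model_out)
        except (json.JSONDecodeError, KeyError):
            return f"Model did not return valid tool calls: {model_out}"
        return application(funs, args)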