WilliamGazeley committed
Commit e40d8d8 · 1 Parent(s): 1d120a6

Fix test pathing

Files changed (4)
  1. pytest.ini +2 -1
  2. src/app.py +34 -10
  3. src/functioncall.py +2 -2
  4. tests/test_quality.py +7 -4
pytest.ini CHANGED
@@ -1 +1,2 @@
-pythonpath=.,src
+[pytest]
+pythonpath=src
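
Note on the pytest.ini change: the old file had no [pytest] section header, so pytest ignored the pythonpath setting entirely; the new file adds the header and points pythonpath at src (the pythonpath ini option is built into pytest 7.0+). A minimal sketch of what this enables, assuming a hypothetical tests/test_import.py that is not part of this commit:

# tests/test_import.py (hypothetical): checks that pythonpath=src lets tests
# import modules under src/ by their bare names.
def test_app_is_importable():
    import app  # resolves to src/app.py because pytest adds "src" to sys.path

    assert callable(app.get_response)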
src/app.py CHANGED
@@ -13,7 +13,7 @@ def init_llm():
     return llm
 
 
-def get_response(prompt):
+def function_agent(prompt):
     try:
         return llm.generate_function_call(
             prompt, config.chat_template, config.num_fewshot, config.max_depth
@@ -22,11 +22,12 @@ def get_response(prompt):
         return f"An error occurred: {str(e)}"
 
 
-def get_output(context, user_input):
+def output_agent(context, user_input):
+    """Takes the output of the RAG and generates a final response."""
     try:
         config.status.update(label=":bulb: Preparing answer..")
         script_dir = os.path.dirname(os.path.abspath(__file__))
-        prompt_path = os.path.join(script_dir, 'prompt_assets', 'output_sys_prompt.yml')
+        prompt_path = os.path.join(script_dir, "prompt_assets", "output_sys_prompt.yml")
         prompt_schema = llm.prompter.read_yaml_file(prompt_path)
         sys_prompt = (
             llm.prompter.format_yaml_prompt(prompt_schema, dict())
@@ -41,6 +42,30 @@ def get_output(context, user_input):
     except Exception as e:
         return f"An error occurred: {str(e)}"
 
+def query_agent(prompt):
+    """Modifies the prompt and runs inference on it."""
+    try:
+        config.status.update(label=":brain: Starting inference..")
+        script_dir = os.path.dirname(os.path.abspath(__file__))
+        prompt_path = os.path.join(script_dir, "prompt_assets", "output_sys_prompt.yml")
+        prompt_schema = llm.prompter.read_yaml_file(prompt_path)
+        sys_prompt = llm.prompter.format_yaml_prompt(prompt_schema, dict())
+        convo = [
+            {"role": "system", "content": sys_prompt},
+            {"role": "user", "content": prompt},
+        ]
+        response = llm.run_inference(convo)
+        return response
+    except Exception as e:
+        return f"An error occurred: {str(e)}"
+
+
+def get_response(input_text: str):
+    """This is the main function that generates the final response."""
+    agent_resp = function_agent(input_text)
+    output = output_agent(agent_resp, input_text)
+    return output
+
 
 def main():
     st.title("LLM-ADE 9B Demo")
@@ -51,26 +76,25 @@ def main():
     if input_text:
         with st.status("Generating response...") as status:
             config.status = status
-            agent_resp = get_response(input_text)
-            st.write(get_output(agent_resp, input_text))
+            st.write(get_response(input_text))
             config.status.update(label="Finished!", state="complete", expanded=True)
     else:
         st.warning("Please enter some text to generate a response.")
 
 
-llm = init_llm()
-
-
 def main_headless(prompt: str):
     start = time()
-    agent_resp = get_response(prompt)
-    print("\033[94m" + get_output(agent_resp, prompt) + "\033[0m")
+    print("\033[94m" + get_response(prompt) + "\033[0m")
     print(f"Time taken: {time() - start:.2f}s\n" + "-" * 20)
 
 
+llm = init_llm()
+
+
 if __name__ == "__main__":
     if config.headless:
         import fire
+
         fire.Fire(main_headless)
     else:
         main()
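
A structural note on src/app.py: llm = init_llm() now sits after all the function definitions and just before the __main__ guard, so the model is initialised on import for both the Streamlit and headless paths. This is safe because Python resolves module-level names when a function is called, not when it is defined; a minimal illustrative sketch (the names below are made up, not from the repo):

# Illustrative only: module-level names are looked up at call time.
def greet():
    return f"hello, {name}"  # `name` is resolved when greet() runs

name = "world"  # bound after greet() is defined, but before any call
print(greet())  # -> hello, world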
src/functioncall.py CHANGED
@@ -39,8 +39,7 @@ class ModelInference:
 
     def process_completion_and_validate(self, completion, chat_template):
         if completion:
-            # completion = f"<tool_call>\n{completion}\n</tool_call>"
-            breakpoint()
+            # completion = f"<tool_call>\n{completion}\n</tool_call>"]
             validation, tool_calls, error_message = validate_and_extract_tool_calls(completion)
 
             if validation:
@@ -85,6 +84,7 @@ class ModelInference:
 
         def recursive_loop(prompt, completion, depth):
             nonlocal max_depth
+            breakpoint()
             tool_calls, assistant_message, error_message = self.process_completion_and_validate(completion, chat_template)
             prompt.append({"role": "assistant", "content": assistant_message})
 
tests/test_quality.py CHANGED
@@ -1,10 +1,11 @@
 import sys
 import json
 from io import StringIO
+from app import get_response
 
 def test_quality():
     """Tests if the expected functions and values are used"""
-    with open("qa_questions.json") as f:
+    with open("tests/qa_questions.json") as f:
         qs = json.load(f)
 
     for q in qs:
@@ -12,11 +13,13 @@ def test_quality():
         stdout = StringIO()
         sys.stdout = stdout
 
-        for include in q['includes']:
+        get_response(q["question"])
+
+        for include in q["expecteds"]["includes"]:
             assert include in stdout.getvalue(), f"Expected {include} in output"
-        for exclude in q['excludes']:
+        for exclude in q["expecteds"]["excludes"]:
             assert exclude not in stdout.getvalue(), f"Expected {exclude} not in output"
-        for function in q['functions']:
+        for function in q["expecteds"]["functions"]:
             assert f"Invoking function call {function}" in stdout.getvalue(), f"{function} was not invoked"
 
         # Restore stdout
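
The rewritten test drives get_response() on each fixture entry and checks the captured stdout, indexing the fixture as q["question"] and q["expecteds"]["includes"/"excludes"/"functions"]. A sketch of the shape tests/qa_questions.json must have to satisfy those lookups; the field names follow the test code, while the values and the function name are invented for illustration:

# Illustrative fixture entry for tests/qa_questions.json (values are made up).
example_entry = {
    "question": "What is Apple's latest closing price?",
    "expecteds": {
        "includes": ["AAPL"],                     # substrings expected in stdout
        "excludes": ["error"],                    # substrings that must not appear
        "functions": ["get_stock_fundamentals"],  # printed as "Invoking function call <name>"
    },
}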