themissingCRAM commited on
Commit
b05ca59
·
1 Parent(s): ed6649d

agent config so they can be fast, i dont care about the accuracy

Browse files
Files changed (4) hide show
  1. Bakery_Shop_Tools.py +7 -4
  2. Constants.py +1 -1
  3. Data_Initialisation.py +1 -0
  4. app.py +15 -48
Bakery_Shop_Tools.py CHANGED
@@ -3,11 +3,13 @@ from chromadb.utils import embedding_functions
3
  from langchain.docstore.document import Document
4
  from smolagents import Tool, tool
5
  from sqlalchemy import (
6
- Engine, text
7
  )
8
  import os
9
  from smolagents import HfApiModel
10
- ENGINE:Engine|None = None
 
 
11
 
12
  @tool
13
  def bakery_ingredient_order_sql_database(query: str) -> str:
@@ -86,8 +88,8 @@ class RetrieverTool(Tool):
86
  }
87
  output_type = "string"
88
 
89
- def __init__(self, docs: list[Document], model:HfApiModel,name_to_be_inserted:str,
90
- description_to_be_inserted:str,inputs_to_be_inserted:dict[dict],
91
  collection_name: str = "baking_recipes", enable_summary: bool = False,
92
  **kwargs):
93
  super().__init__(**kwargs)
@@ -112,6 +114,7 @@ class RetrieverTool(Tool):
112
  name = name_to_be_inserted
113
  description = description_to_be_inserted
114
  inputs = inputs_to_be_inserted
 
115
  def forward(self, query: str) -> str:
116
  assert isinstance(query, str), "Your search query must be a string"
117
  docs = self.collection.query(query_texts=[query], n_results=20)
 
3
  from langchain.docstore.document import Document
4
  from smolagents import Tool, tool
5
  from sqlalchemy import (
6
+ Engine, text
7
  )
8
  import os
9
  from smolagents import HfApiModel
10
+
11
+ ENGINE: Engine | None = None
12
+
13
 
14
  @tool
15
  def bakery_ingredient_order_sql_database(query: str) -> str:
 
88
  }
89
  output_type = "string"
90
 
91
+ def __init__(self, docs: list[Document], model: HfApiModel, name_to_be_inserted: str,
92
+ description_to_be_inserted: str, inputs_to_be_inserted: dict[dict],
93
  collection_name: str = "baking_recipes", enable_summary: bool = False,
94
  **kwargs):
95
  super().__init__(**kwargs)
 
114
  name = name_to_be_inserted
115
  description = description_to_be_inserted
116
  inputs = inputs_to_be_inserted
117
+
118
  def forward(self, query: str) -> str:
119
  assert isinstance(query, str), "Your search query must be a string"
120
  docs = self.collection.query(query_texts=[query], n_results=20)
Constants.py CHANGED
@@ -39,7 +39,7 @@ LEGAL_RAG_TOOL_INPUTS = {
39
  }
40
  }
41
 
42
- #Data
43
  BAKERY_ORDERS_DATA = [
44
  {
45
  "order_id": 1,
 
39
  }
40
  }
41
 
42
+ # Data
43
  BAKERY_ORDERS_DATA = [
44
  {
45
  "order_id": 1,
Data_Initialisation.py CHANGED
@@ -27,6 +27,7 @@ from Constants import (
27
  )
28
  from Bakery_Shop_Tools import RetrieverTool
29
 
 
30
  def init_db():
31
  _engine = create_engine("sqlite:///sqlite3.sqlite")
32
  metadata_obj = MetaData()
 
27
  )
28
  from Bakery_Shop_Tools import RetrieverTool
29
 
30
+
31
  def init_db():
32
  _engine = create_engine("sqlite:///sqlite3.sqlite")
33
  metadata_obj = MetaData()
app.py CHANGED
@@ -20,52 +20,18 @@ from Constants import (
20
  LAW_QUERY1, SQL_QUERY2, )
21
 
22
  from Bakery_Shop_Tools import bakery_ingredient_order_sql_database
23
- from Data_Initialisation import init_rag,init_db
24
-
25
- # # Get your own keys from https://cloud.langfuse.com
26
- # LANGFUSE_PUBLIC_KEY = "pk-lf-..."
27
- # LANGFUSE_SECRET_KEY = "sk-lf-..."
28
- # os.environ["LANGFUSE_PUBLIC_KEY"] = LANGFUSE_PUBLIC_KEY
29
- # os.environ["LANGFUSE_SECRET_KEY"] = LANGFUSE_SECRET_KEY
30
- # os.environ["LANGFUSE_HOST"] = "https://cloud.langfuse.com" # 🇪🇺 EU region example
31
- #
32
- # LANGFUSE_AUTH = base64.b64encode(
33
- # f"{LANGFUSE_PUBLIC_KEY}:{LANGFUSE_SECRET_KEY}".encode()
34
- # ).decode()
35
- #
36
- # os.environ["OTEL_EXPORTER_OTLP_ENDPOINT"] = os.environ.get("LANGFUSE_HOST") + "/api/public/otel"
37
- # os.environ["OTEL_EXPORTER_OTLP_HEADERS"] = f"Authorization=Basic {LANGFUSE_AUTH}"
38
- # from opentelemetry.sdk.trace import TracerProvider
39
- # from openinference.instrumentation.smolagents import SmolagentsInstrumentor
40
- # from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
41
- # from opentelemetry.sdk.trace.export import SimpleSpanProcessor
42
- #
43
- # # Create a TracerProvider for OpenTelemetry
44
- # trace_provider = TracerProvider()
45
- #
46
- # # Add a SimpleSpanProcessor with the OTLPSpanExporter to send traces
47
- # trace_provider.add_span_processor(SimpleSpanProcessor(OTLPSpanExporter()))
48
- #
49
- # # Set the global default tracer provider
50
- # from opentelemetry import trace
51
- #
52
- # trace.set_tracer_provider(trace_provider)
53
- # tracer = trace.get_tracer(__name__)
54
- #
55
- # # Instrument smolagents with the configured provider
56
- # SmolagentsInstrumentor().instrument(tracer_provider=trace_provider)
57
  @spaces.GPU
58
  def dummy():
59
  pass
60
 
61
 
62
-
63
  if __name__ == "__main__":
64
- # preparing tools
65
  engine = init_db()
66
  Bakery_Shop_Tools.ENGINE = engine
67
 
68
- # model
69
  model_id = "Qwen/Qwen2.5-Coder-32B-Instruct"
70
  # model_id="meta-llama/Llama-3.2-3B-Instruct",
71
 
@@ -106,9 +72,9 @@ if __name__ == "__main__":
106
  baking_recipes_retriever_agent = CodeAgent(
107
  tools=[baking_recipes_retriever_tool],
108
  model=model,
109
- max_steps=3,
110
  verbosity_level=1,
111
- planning_interval=1,
112
  name="baking_recipe_retriever_agent",
113
  description=
114
  '''
@@ -117,13 +83,13 @@ if __name__ == "__main__":
117
  not json or python or any other programming codes.
118
  '''
119
  )
120
- # more complicated so there are bit more max steps and planning intervals
121
  legals_retriever_agent = CodeAgent(
122
  tools=[legals_retriever_tool],
123
  model=model,
124
- max_steps=3,
125
- verbosity_level=2,
126
- planning_interval=2,
127
  name="legals_retriever_agent",
128
  description=
129
  '''
@@ -132,11 +98,11 @@ if __name__ == "__main__":
132
  not json or python or any other programming codes.
133
  '''
134
  )
135
- # more complicated so there are bit more max steps and planning intervals
136
  retriever_agent = CodeAgent(
137
  tools=[],
138
  model=model,
139
- max_steps=3,
140
  verbosity_level=2,
141
  planning_interval=1,
142
  managed_agents=[baking_recipes_retriever_agent, legals_retriever_agent],
@@ -198,7 +164,7 @@ if __name__ == "__main__":
198
  time.sleep(2)
199
 
200
  message = message + " " + transcriber({"sampling_rate": sr, "raw": y})["text"]
201
- print("new msseage:", message)
202
  return message
203
 
204
 
@@ -249,7 +215,6 @@ if __name__ == "__main__":
249
  rag_q_button = gr.Button(RAG_QUESTION)
250
  sql_q1_button = gr.Button(SQL_QUERY)
251
  sql_q2_button = gr.Button(SQL_QUERY2)
252
-
253
  combi_button = gr.Button(RAG_QUESTION + " " + SQL_QUERY)
254
  legal_q_button = gr.Button(LAW_QUERY1)
255
  gr.Markdown("This gradio app has many accordion UI components that you can click to expand the UI component")
@@ -261,9 +226,10 @@ if __name__ == "__main__":
261
  mbox_submit_event = message_box.submit(enter_message,
262
  [message_box, chatbot],
263
  [message_box, chatbot])
264
-
265
  audio_stream = audio_interface.change(transcribe, inputs=[audio_interface, message_box],
266
  outputs=[message_box])
 
 
267
  rag_q_click_event = rag_q_button.click(enter_message,
268
  [rag_q_button, chatbot],
269
  [message_box, chatbot])
@@ -278,6 +244,7 @@ if __name__ == "__main__":
278
  [combi_button, chatbot],
279
  [message_box, chatbot])
280
 
 
281
  legal_q_click_event = legal_q_button.click(enter_message,
282
  [legal_q_button, chatbot],
283
  [message_box, chatbot])
 
20
  LAW_QUERY1, SQL_QUERY2, )
21
 
22
  from Bakery_Shop_Tools import bakery_ingredient_order_sql_database
23
+ from Data_Initialisation import init_rag, init_db
24
+
25
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  @spaces.GPU
27
  def dummy():
28
  pass
29
 
30
 
 
31
  if __name__ == "__main__":
 
32
  engine = init_db()
33
  Bakery_Shop_Tools.ENGINE = engine
34
 
 
35
  model_id = "Qwen/Qwen2.5-Coder-32B-Instruct"
36
  # model_id="meta-llama/Llama-3.2-3B-Instruct",
37
 
 
72
  baking_recipes_retriever_agent = CodeAgent(
73
  tools=[baking_recipes_retriever_tool],
74
  model=model,
75
+ max_steps=2,
76
  verbosity_level=1,
77
+ # planning_interval=1,
78
  name="baking_recipe_retriever_agent",
79
  description=
80
  '''
 
83
  not json or python or any other programming codes.
84
  '''
85
  )
86
+ # more complicated so there are a bit more max steps and planning intervals
87
  legals_retriever_agent = CodeAgent(
88
  tools=[legals_retriever_tool],
89
  model=model,
90
+ max_steps=2,
91
+ verbosity_level=1,
92
+ planning_interval=1,
93
  name="legals_retriever_agent",
94
  description=
95
  '''
 
98
  not json or python or any other programming codes.
99
  '''
100
  )
101
+ # more complicated so there are a bit more max steps and planning intervals
102
  retriever_agent = CodeAgent(
103
  tools=[],
104
  model=model,
105
+ max_steps=2,
106
  verbosity_level=2,
107
  planning_interval=1,
108
  managed_agents=[baking_recipes_retriever_agent, legals_retriever_agent],
 
164
  time.sleep(2)
165
 
166
  message = message + " " + transcriber({"sampling_rate": sr, "raw": y})["text"]
167
+ print("new message:", message)
168
  return message
169
 
170
 
 
215
  rag_q_button = gr.Button(RAG_QUESTION)
216
  sql_q1_button = gr.Button(SQL_QUERY)
217
  sql_q2_button = gr.Button(SQL_QUERY2)
 
218
  combi_button = gr.Button(RAG_QUESTION + " " + SQL_QUERY)
219
  legal_q_button = gr.Button(LAW_QUERY1)
220
  gr.Markdown("This gradio app has many accordion UI components that you can click to expand the UI component")
 
226
  mbox_submit_event = message_box.submit(enter_message,
227
  [message_box, chatbot],
228
  [message_box, chatbot])
 
229
  audio_stream = audio_interface.change(transcribe, inputs=[audio_interface, message_box],
230
  outputs=[message_box])
231
+
232
+ # query examples events
233
  rag_q_click_event = rag_q_button.click(enter_message,
234
  [rag_q_button, chatbot],
235
  [message_box, chatbot])
 
244
  [combi_button, chatbot],
245
  [message_box, chatbot])
246
 
247
+ # cancels
248
  legal_q_click_event = legal_q_button.click(enter_message,
249
  [legal_q_button, chatbot],
250
  [message_box, chatbot])