jarguello76 committed
Commit 692591b · verified · Parent(s): 6c8d237

Update tools.py

Files changed (1):
  1. tools.py +201 -400
tools.py CHANGED
@@ -1,404 +1,205 @@
- from smolagents import Tool
- import random
- from huggingface_hub import list_models
- import requests
  import os
- import sqlite3
- from googletrans import Translator
- from gtts import gTTS
- import speech_recognition as sr
- import cv2
  import numpy as np
- from textblob import TextBlob
-
- # Initialize the DuckDuckGo search tool
- # search_tool = DuckDuckGoSearchTool()
-
- class WeatherInfoTool(Tool):
-     name = "weather_info"
-     description = "Fetches weather information for a given location."
-     inputs = {
-         "location": {
-             "type": "string",
-             "description": "The location to get weather information for."
-         }
-     }
-     output_type = "string"
-
-     def forward(self, location: str):
-         # Use a real weather API here
-         api_key = os.getenv("WEATHER_API_KEY")
-         if not api_key:
-             return "Weather API key not found."
-
-         try:
-             response = requests.get(f"http://api.weatherapi.com/v1/current.json?key={api_key}&q={location}")
-             response.raise_for_status()
-             data = response.json()
-             condition = data["current"]["condition"]["text"]
-             temp_c = data["current"]["temp_c"]
-             return f"Weather in {location}: {condition}, {temp_c}°C"
-         except Exception as e:
-             return f"Error fetching weather for {location}: {str(e)}"
-
- class HubStatsTool(Tool):
-     name = "hub_stats"
-     description = "Fetches the most downloaded model from a specific author on the Hugging Face Hub."
-     inputs = {
-         "author": {
-             "type": "string",
-             "description": "The username of the model author/organization to find models from."
-         }
-     }
-     output_type = "string"
-
-     def forward(self, author: str):
-         try:
-             # List models from the specified author, sorted by downloads
-             models = list(list_models(author=author, sort="downloads", direction=-1, limit=1))
-
-             if models:
-                 model = models[0]
-                 return f"The most downloaded model by {author} is {model.id} with {model.downloads:,} downloads."
-             else:
-                 return f"No models found for author {author}."
-         except Exception as e:
-             return f"Error fetching models for {author}: {str(e)}"
-
- class CalendarTool(Tool):
-     name = "calendar"
-     description = "Manages and retrieves information about dates and events."
-     inputs = {
-         "action": {
-             "type": "string",
-             "description": "The action to perform (e.g., 'add', 'get', 'delete')."
-         },
-         "date": {
-             "type": "string",
-             "description": "The date of the event (format: YYYY-MM-DD)."
-         },
-         "event": {
-             "type": "string",
-             "description": "The event description.",
-             "nullable": True  # Add this line to specify that 'event' is nullable
-         }
-     }
-     output_type = "string"
-
-     def __init__(self):
-         self.events = {}
-
-     def forward(self, action: str, date: str, event: str = None):
-         if action == "add":
-             self.events[date] = event
-             return f"Event '{event}' added to {date}."
-         elif action == "get":
-             return f"Event on {date}: {self.events.get(date, 'No event found.')}"
-         elif action == "delete":
-             if date in self.events:
-                 del self.events[date]
-                 return f"Event on {date} deleted."
-             else:
-                 return f"No event found on {date}."
-         else:
-             return "Invalid action."
-
-
-
- class CalculatorTool(Tool):
-     name = "calculator"
-     description = "Performs mathematical calculations."
-     inputs = {
-         "expression": {
-             "type": "string",
-             "description": "The mathematical expression to evaluate."
-         }
-     }
-     output_type = "string"
-
-     def forward(self, expression: str):
-         try:
-             result = eval(expression)
-             return f"The result of the expression '{expression}' is {result}."
-         except Exception as e:
-             return f"Error evaluating expression: {str(e)}"

- class EmailTool(Tool):
-     name = "email"
-     description = "Sends and receives emails."
-     inputs = {
-         "action": {
-             "type": "string",
-             "description": "The action to perform (e.g., 'send')."
-         },
-         "to": {
-             "type": "string",
-             "description": "The recipient's email address."
-         },
-         "subject": {
-             "type": "string",
-             "description": "The subject of the email."
-         },
-         "body": {
-             "type": "string",
-             "description": "The body of the email."
-         }
-     }
-     output_type = "string"
-
-     def __init__(self, smtp_server, smtp_port, email, password):
-         self.smtp_server = smtp_server
-         self.smtp_port = smtp_port
-         self.email = email
-         self.password = password
-
-     def forward(self, action: str, to: str, subject: str, body: str):
-         if action == "send":
-             try:
-                 msg = MIMEText(body)
-                 msg['Subject'] = subject
-                 msg['From'] = self.email
-                 msg['To'] = to
-
-                 with smtplib.SMTP(self.smtp_server, self.smtp_port) as server:
-                     server.starttls()
-                     server.login(self.email, self.password)
-                     server.sendmail(self.email, [to], msg.as_string())
-
-                 return f"Email sent to {to}."
-             except Exception as e:
-                 return f"Error sending email: {str(e)}"
-         else:
-             return "Invalid action."
-
- class FileManagementTool(Tool):
-     name = "file_management"
-     description = "Handles file operations like reading, writing, and managing files."
-     inputs = {
-         "action": {
-             "type": "string",
-             "description": "The action to perform (e.g., 'read', 'write', 'delete')."
-         },
-         "file_path": {
-             "type": "string",
-             "description": "The path of the file."
-         },
-         "content": {
-             "type": "string",
-             "description": "The content to write to the file.",
-             "nullable": True  # Add this line to specify that 'content' is nullable
-         }
-     }
-     output_type = "string"
-
-     def forward(self, action: str, file_path: str, content: str = None):
-         if action == "read":
-             try:
-                 with open(file_path, 'r') as file:
-                     content = file.read()
-                 return f"Content of {file_path}: {content}"
-             except Exception as e:
-                 return f"Error reading file: {str(e)}"
-         elif action == "write":
-             try:
-                 with open(file_path, 'w') as file:
-                     file.write(content)
-                 return f"Content written to {file_path}."
-             except Exception as e:
-                 return f"Error writing to file: {str(e)}"
-         elif action == "delete":
-             try:
-                 os.remove(file_path)
-                 return f"File {file_path} deleted."
-             except Exception as e:
-                 return f"Error deleting file: {str(e)}"
          else:
-             return "Invalid action."
-
-
- class DatabaseQueryTool(Tool):
-     name = "database_query"
-     description = "Interacts with databases for storing and retrieving information."
-     inputs = {
-         "action": {
-             "type": "string",
-             "description": "The action to perform (e.g., 'query', 'insert')."
-         },
-         "query": {
-             "type": "string",
-             "description": "The SQL query to execute."
-         }
-     }
-     output_type = "string"
-
-     def __init__(self, db_path):
-         self.db_path = db_path
-
-     def forward(self, action: str, query: str):
-         try:
-             conn = sqlite3.connect(self.db_path)
-             cursor = conn.cursor()
-
-             if action == "query":
-                 cursor.execute(query)
-                 results = cursor.fetchall()
-                 return f"Query results: {results}"
-             elif action == "insert":
-                 cursor.execute(query)
-                 conn.commit()
-                 return "Data inserted successfully."
-             else:
-                 return "Invalid action."
-
-         except Exception as e:
-             return f"Error executing query: {str(e)}"
-         finally:
-             conn.close()
-
- class TranslationTool(Tool):
-     name = "translation"
-     description = "Translates text between different languages."
-     inputs = {
-         "text": {
-             "type": "string",
-             "description": "The text to translate."
-         },
-         "src_lang": {
-             "type": "string",
-             "description": "The source language code."
-         },
-         "dest_lang": {
-             "type": "string",
-             "description": "The destination language code."
-         }
-     }
-     output_type = "string"
-
-     def forward(self, text: str, src_lang: str, dest_lang: str):
-         try:
-             translator = Translator()
-             translation = translator.translate(text, src=src_lang, dest=dest_lang)
-             return f"Translated text: {translation.text}"
-         except Exception as e:
-             return f"Error translating text: {str(e)}"
-
- class TextToSpeechTool(Tool):
-     name = "text_to_speech"
-     description = "Converts text to speech."
-     inputs = {
-         "text": {
-             "type": "string",
-             "description": "The text to convert to speech."
-         }
-     }
-     output_type = "string"
-
-     def forward(self, text: str):
-         try:
-             tts = gTTS(text=text, lang='en')
-             tts.save("output.mp3")
-             return "Text converted to speech and saved as output.mp3."
-         except Exception as e:
-             return f"Error converting text to speech: {str(e)}"
-
- class SpeechToTextTool(Tool):
-     name = "speech_to_text"
-     description = "Converts speech to text."
-     inputs = {
-         "audio_file": {
-             "type": "string",
-             "description": "The path to the audio file to convert to text."
-         }
-     }
-     output_type = "string"
-
-     def forward(self, audio_file: str):
-         try:
-             recognizer = sr.Recognizer()
-             with sr.AudioFile(audio_file) as source:
-                 audio = recognizer.record(source)
-             text = recognizer.recognize_google(audio)
-             return f"Converted speech to text: {text}"
-         except Exception as e:
-             return f"Error converting speech to text: {str(e)}"
-
- class ImageRecognitionTool(Tool):
-     name = "image_recognition"
-     description = "Analyzes and interprets images."
-     inputs = {
-         "image_path": {
-             "type": "string",
-             "description": "The path to the image to analyze."
-         }
-     }
-     output_type = "string"
-
-     def forward(self, image_path: str):
-         try:
-             image = cv2.imread(image_path)
-             gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
-             faces = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml').detectMultiScale(gray, 1.3, 5)
-             return f"Found {len(faces)} faces in the image."
-         except Exception as e:
-             return f"Error analyzing image: {str(e)}"
-
- class NLPTool(Tool):
-     name = "nlp"
-     description = "Performs advanced text processing tasks like sentiment analysis, named entity recognition, etc."
-     inputs = {
-         "text": {
-             "type": "string",
-             "description": "The text to analyze."
-         },
-         "task": {
-             "type": "string",
-             "description": "The NLP task to perform (e.g., 'sentiment', 'entities')."
-         }
-     }
-     output_type = "string"
-
-     def forward(self, text: str, task: str):
-         blob = TextBlob(text)
-         if task == "sentiment":
-             sentiment = blob.sentiment
-             return f"Sentiment analysis: Polarity={sentiment.polarity}, Subjectivity={sentiment.subjectivity}"
-         elif task == "entities":
-             entities = blob.noun_phrases
-             return f"Named entities: {entities}"
-         else:
-             return "Invalid task."
-
- class APIIntegrationTool(Tool):
-     name = "api_integration"
-     description = "Interacts with various external APIs for fetching or sending data."
-     inputs = {
-         "api_url": {
-             "type": "string",
-             "description": "The URL of the API endpoint."
-         },
-         "method": {
-             "type": "string",
-             "description": "The HTTP method to use (e.g., 'GET', 'POST')."
-         },
-         "data": {
-             "type": "string",
-             "description": "The data to send with the request.",
-             "nullable": True  # Add this line to specify that 'data' is nullable
-         }
-     }
-     output_type = "string"
-
-     def forward(self, api_url: str, method: str, data: str = None):
-         try:
-             if method == "GET":
-                 response = requests.get(api_url)
-             elif method == "POST":
-                 response = requests.post(api_url, json=data)
-             else:
-                 return "Invalid method."
-
-             response.raise_for_status()
-             return f"API response: {response.json()}"
-         except Exception as e:
-             return f"Error interacting with API: {str(e)}"
-
+ """LangGraph Agent with CSV-based Vector Store"""
+
  import os
+ import ast
+ import pandas as pd
  import numpy as np
+ from sklearn.metrics.pairwise import cosine_similarity
+ from dotenv import load_dotenv
+ from langgraph.graph import START, StateGraph, MessagesState
+ from langgraph.prebuilt import tools_condition, ToolNode
+ from langchain_google_genai import ChatGoogleGenerativeAI
+ from langchain_groq import ChatGroq
+ from langchain_huggingface import ChatHuggingFace, HuggingFaceEndpoint
+ from langchain_community.tools.tavily_search import TavilySearchResults
+ from langchain_community.document_loaders import WikipediaLoader, ArxivLoader
+ from langchain_core.messages import SystemMessage, HumanMessage
+ from langchain_core.tools import tool
+
+ load_dotenv()
+
+ # Math tools
+ @tool
+ def multiply(a: int, b: int) -> int:
+     """Multiply two numbers."""
+     return a * b
+
+ @tool
+ def add(a: int, b: int) -> int:
+     """Add two numbers."""
+     return a + b
+
+ @tool
+ def subtract(a: int, b: int) -> int:
+     """Subtract two numbers."""
+     return a - b
+
+ @tool
+ def divide(a: int, b: int) -> float:
+     """Divide two numbers."""
+     if b == 0:
+         raise ValueError("Cannot divide by zero.")
+     return a / b
+
+ @tool
+ def modulus(a: int, b: int) -> int:
+     """Get the modulus of two numbers."""
+     return a % b
+
+ # Search tools
+ @tool
+ def wiki_search(query: str) -> str:
+     """Search Wikipedia for a query and return maximum 2 results."""
+     search_docs = WikipediaLoader(query=query, load_max_docs=2).load()
+     formatted_search_docs = "\n\n---\n\n".join(
+         [f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content}\n</Document>'
+          for doc in search_docs])
+     return formatted_search_docs
+
+ @tool
+ def web_search(query: str) -> str:
+     """Search Tavily for a query and return maximum 3 results."""
+     search_docs = TavilySearchResults(max_results=3).invoke(query=query)
+     formatted_search_docs = "\n\n---\n\n".join(
+         [f'<Document source="{doc.get("url", "")}" title="{doc.get("title", "")}"/>\n{doc.get("content", "")}\n</Document>'
+          for doc in search_docs])
+     return formatted_search_docs
+
+ @tool
+ def arxiv_search(query: str) -> str:
+     """Search Arxiv for a query and return maximum 3 results."""
+     search_docs = ArxivLoader(query=query, load_max_docs=3).load()
+     formatted_search_docs = "\n\n---\n\n".join(
+         [f'<Document source="{doc.metadata["source"]}" page="{doc.metadata.get("page", "")}"/>\n{doc.page_content[:1000]}\n</Document>'
+          for doc in search_docs])
+     return formatted_search_docs
+
+ # CSV-based Vector Store Class
+ class CSVVectorStore:
+     def __init__(self, csv_file_path: str):
+         """Initialize the CSV vector store."""
+         self.df = pd.read_csv(csv_file_path)
+         # Convert string representation of embeddings to numpy arrays
+         self.df['embedding'] = self.df['embedding'].apply(ast.literal_eval)
+         self.embeddings_matrix = np.array(self.df['embedding'].tolist())
+
+     def similarity_search(self, query_embedding: np.ndarray, k: int = 1):
+         """Find most similar documents to the query embedding."""
+         # Calculate cosine similarity
+         similarities = cosine_similarity([query_embedding], self.embeddings_matrix)[0]
+
+         # Get top k indices
+         top_indices = np.argsort(similarities)[-k:][::-1]
+
+         # Return results in a format similar to LangChain's Document
+         results = []
+         for idx in top_indices:
+             class Document:
+                 def __init__(self, page_content, metadata):
+                     self.page_content = page_content
+                     self.metadata = metadata
+
+             doc = Document(
+                 page_content=self.df.iloc[idx]['content'],
+                 metadata=ast.literal_eval(self.df.iloc[idx]['metadata']) if isinstance(self.df.iloc[idx]['metadata'], str) else self.df.iloc[idx]['metadata']
+             )
+             results.append(doc)
+
+         return results
+
+ # System prompt
+ system_prompt = """You are a helpful assistant tasked with answering questions using a set of tools. Now, I will ask you a question. Report your thoughts, and finish your answer with the following template: FINAL ANSWER: [YOUR FINAL ANSWER]. YOUR FINAL ANSWER should be a number OR as few words as possible OR a comma separated list of numbers and/or strings. If you are asked for a number, do not use comma to write your number neither use units such as $ or percent sign unless specified otherwise. If you are asked for a string, do not use articles, neither abbreviations (e.g. for cities), and write the digits in plain text unless specified otherwise. If you are asked for a comma separated list, apply the above rules depending of whether the element to be put in the list is a number or a string. Your answer should only start with 'FINAL ANSWER: ', then follows with the answer."""
+
+ # Tools list
+ tools = [
+     multiply,
+     add,
+     subtract,
+     divide,
+     modulus,
+     wiki_search,
+     web_search,
+     arxiv_search,
+ ]
+
+ def build_graph(provider: str = "groq", csv_file_path: str = "embeddings.csv"):
+     """Build the graph with CSV-based vector store."""
+
+     # Initialize CSV vector store
+     vector_store = CSVVectorStore(csv_file_path)
+
+     # System message
+     sys_msg = SystemMessage(content=system_prompt)
+
+     # Initialize LLM based on provider
+     if provider == "google":
+         llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0)
+     elif provider == "groq":
+         llm = ChatGroq(model="qwen-qwq-32b", temperature=0)
+     elif provider == "huggingface":
+         llm = ChatHuggingFace(
+             llm=HuggingFaceEndpoint(
+                 url="https://api-inference.huggingface.co/models/Meta-DeepLearning/llama-2-7b-chat-hf",
+                 temperature=0,
+             ),
+         )
+     else:
+         raise ValueError("Invalid provider. Choose 'google', 'groq' or 'huggingface'.")
+
+     # Bind tools to LLM
+     llm_with_tools = llm.bind_tools(tools)
+
+     # Helper function to get query embedding (simplified - you might want to use the same embedding model)
+     def get_query_embedding(query: str) -> np.ndarray:
+         # For now, return a random embedding - in practice, you'd use the same embedding model
+         # that was used to create the CSV embeddings
+         return np.random.rand(768)  # Assuming 768-dim embeddings
+
+     # Nodes
+     def assistant(state: MessagesState):
+         """Assistant node."""
+         return {"messages": [llm_with_tools.invoke(state["messages"])]}
+
+     def retriever(state: MessagesState):
+         """Retriever node using CSV vector store."""
+         query = state["messages"][-1].content if state["messages"] else ""
+
+         # Get query embedding (this is simplified - you'd use proper embedding model)
+         query_embedding = get_query_embedding(query)
+
+         # Search for similar documents
+         similar_docs = vector_store.similarity_search(query_embedding, k=1)
+
+         if similar_docs:
+             example_msg = HumanMessage(
+                 content=f"Here I provide a similar question and answer for reference: \n\n{similar_docs[0].page_content}",
+             )
+             return {"messages": [sys_msg] + state["messages"] + [example_msg]}
          else:
+             return {"messages": [sys_msg] + state["messages"]}
+
+     # Build graph
+     builder = StateGraph(MessagesState)
+     builder.add_node("retriever", retriever)
+     builder.add_node("assistant", assistant)
+     builder.add_node("tools", ToolNode(tools))
+
+     builder.add_edge(START, "retriever")
+     builder.add_edge("retriever", "assistant")
+     builder.add_conditional_edges("assistant", tools_condition)
+     builder.add_edge("tools", "assistant")
+
+     return builder.compile()
+
+ # Test
+ if __name__ == "__main__":
+     question = "When was a picture of St. Thomas Aquinas first added to the Wikipedia page on the Principle of double effect?"
+
+     # Build the graph (you'll need to provide the path to your CSV file)
+     graph = build_graph(provider="groq", csv_file_path="your_embeddings.csv")
+
+     # Run the graph
+     messages = [HumanMessage(content=question)]
+     messages = graph.invoke({"messages": messages})
+
+     for m in messages["messages"]:
+         m.pretty_print()
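
Note: the committed get_query_embedding helper returns a random vector as a placeholder, as its own comments say. A minimal sketch of what a real helper might look like, assuming the 'embedding' column in the CSV was produced by a 768-dimensional sentence-transformers model (the model name below is a hypothetical choice, not taken from this commit):

import numpy as np
from sentence_transformers import SentenceTransformer

# Hypothetical encoder; it must be the same model that produced the
# 'embedding' column in the CSV for the cosine similarities to be meaningful.
_encoder = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")  # 768-dim output

def get_query_embedding(query: str) -> np.ndarray:
    """Encode the query with the same model used to build the CSV embeddings."""
    # encode() returns a NumPy array by default, shape (768,) for this model.
    return _encoder.encode(query)

For this to line up with CSVVectorStore, the CSV is expected to carry 'content', 'metadata', and 'embedding' columns, with each embedding stored as a string-encoded list that the store parses via ast.literal_eval.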