abtsousa commited on
Commit
0242ef6
·
1 Parent(s): ffc544c

Update configuration and enhance tool functionality

Browse files

- Set API_BASE_URL, API_KEY_ENV_VAR, and MODEL_NAME to OpenRouter values and add MAX_TOKENS in config.py
- Import ChatOllama in nodes.py, broaden backoff retry handling, and add optional token-based message trimming
- Ignore type checking for process_questions in app.py
- Add langchain-ollama dependency in pyproject.toml and uv.lock
- Refactor tool imports and disable the code interpreter in __init__.py; add token-based output trimming in wikipedia.py

Files changed (7) hide show
  1. agent/config.py +7 -6
  2. agent/nodes.py +29 -4
  3. app.py +1 -1
  4. pyproject.toml +1 -0
  5. tools/__init__.py +5 -10
  6. tools/wikipedia.py +29 -3
  7. uv.lock +28 -0
agent/config.py CHANGED
@@ -1,12 +1,13 @@
1
  from typing import Literal
2
 
3
- #API_BASE_URL = "https://openrouter.ai/api/v1"
4
- #MODEL_NAME = "openai/gpt-oss-120b:floor"
5
- #API_KEY_ENV_VAR = "OPENROUTER_API_KEY"
6
 
7
- API_BASE_URL = "https://api.openai.com/v1/"
8
- MODEL_NAME = "gpt-5"
9
- API_KEY_ENV_VAR = "OPENAI_API_KEY_ORACLEBOT"
 
10
 
11
  MODEL_TEMPERATURE = 0.7
12
 
 
1
  from typing import Literal
2
 
3
+ API_BASE_URL = "https://openrouter.ai/api/v1"
4
+ API_KEY_ENV_VAR = "OPENROUTER_API_KEY"
5
+ MODEL_NAME = "google/gemini-2.5-pro"
6
 
7
+ #API_BASE_URL = "https://api.openai.com/v1/"
8
+ #MODEL_NAME = "gpt-5"
9
+ #API_KEY_ENV_VAR = "OPENAI_API_KEY_ORACLEBOT"
10
+ MAX_TOKENS = None
11
 
12
  MODEL_TEMPERATURE = 0.7
13
 
agent/nodes.py CHANGED
@@ -9,6 +9,7 @@ from langchain_core.language_models.base import LanguageModelInput
9
  from langchain_google_genai import ChatGoogleGenerativeAI
10
  from langchain_openai import ChatOpenAI
11
  from langchain_deepseek import ChatDeepSeek
 
12
  from pydantic import BaseModel, Field, SecretStr
13
  from agent.prompts import get_system_prompt
14
  from agent.state import State
@@ -17,8 +18,9 @@ from langgraph.prebuilt import ToolNode
17
  import backoff
18
  import openai
19
  import re
 
20
 
21
- from agent.config import API_BASE_URL, MODEL_NAME, API_KEY_ENV_VAR, MODEL_TEMPERATURE
22
 
23
  from dotenv import load_dotenv
24
  load_dotenv()
@@ -39,6 +41,16 @@ def _get_model() -> BaseChatModel:
39
 
40
  api_key = os.getenv(API_KEY_ENV_VAR)
41
 
 
 
 
 
 
 
 
 
 
 
42
  return ChatOpenAI(
43
  api_key=SecretStr(api_key) if api_key else None,
44
  base_url=API_BASE_URL,
@@ -68,12 +80,25 @@ def _bind_model(model: BaseChatModel) -> Runnable[LanguageModelInput, BaseMessag
68
  # Call model node
69
  @backoff.on_exception(
70
  backoff.runtime,
71
- openai.RateLimitError,
72
  value=lambda e: float(match.group(1)) if (match := re.search(r'try again in (\d+(?:\.\d+)?)s', str(e))) else 10.0,
73
- max_tries=20,
74
  )
 
75
  def call_model(state: State, config) -> dict[str, list[BaseMessage]]:
76
- messages = state["messages"]
 
 
 
 
 
 
 
 
 
 
 
 
77
  app_name = config.get('configurable', {}).get("app_name", "OracleBot")
78
 
79
  # Add system prompt if not already present
 
9
  from langchain_google_genai import ChatGoogleGenerativeAI
10
  from langchain_openai import ChatOpenAI
11
  from langchain_deepseek import ChatDeepSeek
12
+ from langchain_ollama import ChatOllama
13
  from pydantic import BaseModel, Field, SecretStr
14
  from agent.prompts import get_system_prompt
15
  from agent.state import State
 
18
  import backoff
19
  import openai
20
  import re
21
+ from langchain_core.messages.utils import trim_messages, count_tokens_approximately
22
 
23
+ from agent.config import API_BASE_URL, MAX_TOKENS, MODEL_NAME, API_KEY_ENV_VAR, MODEL_TEMPERATURE
24
 
25
  from dotenv import load_dotenv
26
  load_dotenv()
 
41
 
42
  api_key = os.getenv(API_KEY_ENV_VAR)
43
 
44
+ # return ChatOllama(
45
+ # model=MODEL_NAME,
46
+ # temperature=MODEL_TEMPERATURE if MODEL_TEMPERATURE else 0.0,
47
+ # metadata={
48
+ # "reasoning": {
49
+ # "effort": "high" # Use high reasoning effort
50
+ # }
51
+ # }
52
+ # )
53
+
54
  return ChatOpenAI(
55
  api_key=SecretStr(api_key) if api_key else None,
56
  base_url=API_BASE_URL,
 
80
  # Call model node
81
  @backoff.on_exception(
82
  backoff.runtime,
83
+ (openai.RateLimitError, openai.InternalServerError),
84
  value=lambda e: float(match.group(1)) if (match := re.search(r'try again in (\d+(?:\.\d+)?)s', str(e))) else 10.0,
85
+ max_tries=200,
86
  )
87
+
88
  def call_model(state: State, config) -> dict[str, list[BaseMessage]]:
89
+ if MAX_TOKENS:
90
+ messages = trim_messages(
91
+ state["messages"],
92
+ strategy="last",
93
+ token_counter=count_tokens_approximately,
94
+ allow_partial=True,
95
+ max_tokens=MAX_TOKENS,
96
+ start_on="human",
97
+ end_on=("human", "tool"),
98
+ )
99
+ else:
100
+ messages = state["messages"]
101
+
102
  app_name = config.get('configurable', {}).get("app_name", "OracleBot")
103
 
104
  # Add system prompt if not already present
app.py CHANGED
@@ -119,7 +119,7 @@ async def run_and_submit_all( profile: gr.OAuthProfile | None):
119
  # Use the cache directory for this session
120
  working_dir = CACHE_DIR
121
 
122
- results_log, answers_payload = await process_questions(agent, questions_data, working_dir)
123
 
124
  # Remove everything before "FINAL ANSWER: " in submitted answers
125
  for answer in answers_payload:
 
119
  # Use the cache directory for this session
120
  working_dir = CACHE_DIR
121
 
122
+ results_log, answers_payload = await process_questions(agent, questions_data, working_dir) # type: ignore
123
 
124
  # Remove everything before "FINAL ANSWER: " in submitted answers
125
  for answer in answers_payload:
pyproject.toml CHANGED
@@ -15,6 +15,7 @@ dependencies = [
15
  "langchain-community>=0.3.27",
16
  "langchain-deepseek>=0.1.4",
17
  "langchain-google-genai>=2.1.9",
 
18
  "langchain-tavily>=0.2.11",
19
  "langchain[google-genai,googlegenai,openai]>=0.3.26",
20
  "langgraph>=0.4.8",
 
15
  "langchain-community>=0.3.27",
16
  "langchain-deepseek>=0.1.4",
17
  "langchain-google-genai>=2.1.9",
18
+ "langchain-ollama>=0.3.6",
19
  "langchain-tavily>=0.2.11",
20
  "langchain[google-genai,googlegenai,openai]>=0.3.26",
21
  "langgraph>=0.4.8",
tools/__init__.py CHANGED
@@ -1,8 +1,6 @@
1
  from .wikipedia import wiki_fetch_article, wiki_parse_html
2
  from .search import web_search
3
  from .code_interpreter import execute_code_multilang
4
- from .files import *
5
- from .calculator import *
6
  from langchain_core.tools import BaseTool
7
  from . import calculator, files
8
 
@@ -17,17 +15,14 @@ def get_all_tools() -> list[BaseTool]:
17
  wiki_fetch_article,
18
  wiki_parse_html,
19
  web_search,
20
- execute_code_multilang,
21
  ]
22
 
23
- # Automatically add all calculator functions
24
- for name in calculator.__all__:
25
- tools.append(getattr(calculator, name))
26
-
27
- # Automatically add all file management functions
28
- for name in files.__all__:
29
- tools.append(getattr(files, name))
30
 
 
31
  return tools
32
 
33
  def list_tools() -> str:
 
1
  from .wikipedia import wiki_fetch_article, wiki_parse_html
2
  from .search import web_search
3
  from .code_interpreter import execute_code_multilang
 
 
4
  from langchain_core.tools import BaseTool
5
  from . import calculator, files
6
 
 
15
  wiki_fetch_article,
16
  wiki_parse_html,
17
  web_search,
18
+ #execute_code_multilang, # Disabled for the repo, enable at your own risk
19
  ]
20
 
21
+ # Automatically add all functions for the remaining tools
22
+ tools.extend([getattr(calculator, name) for name in calculator.__all__])
23
+ tools.extend([getattr(files, name) for name in files.__all__])
 
 
 
 
24
 
25
+ #return [] # Disable all tools
26
  return tools
27
 
28
  def list_tools() -> str:
tools/wikipedia.py CHANGED
@@ -2,6 +2,9 @@ from langchain_core.tools import tool
2
  import wikipediaapi
3
  import requests
4
  from bs4 import BeautifulSoup
 
 
 
5
 
6
  @tool
7
  def wiki_fetch_article(article_title: str) -> str:
@@ -86,11 +89,34 @@ def wiki_parse_html(page_title: str, section_id: int | None = None) -> str:
86
 
87
  # Optional: collapse excessive whitespace
88
  text = str(soup)
89
- return text
 
 
 
 
 
 
 
 
 
 
 
 
90
  except Exception as e:
91
  # Fallback to raw HTML if sanitization fails
92
- return raw_html
93
-
 
 
 
 
 
 
 
 
 
 
 
94
  except requests.RequestException as e:
95
  return f"Error fetching page: {str(e)}"
96
  except Exception as e:
 
2
  import wikipediaapi
3
  import requests
4
  from bs4 import BeautifulSoup
5
+ from langchain_core.messages.utils import count_tokens_approximately, trim_messages
6
+ from langchain_core.messages import HumanMessage
7
+ from agent.config import MAX_TOKENS
8
 
9
  @tool
10
  def wiki_fetch_article(article_title: str) -> str:
 
89
 
90
  # Optional: collapse excessive whitespace
91
  text = str(soup)
92
+
93
+ if MAX_TOKENS:
94
+ # Use trim_messages to fit max tokens
95
+ messages = [HumanMessage(content=text)]
96
+ trimmed_messages = trim_messages(
97
+ messages,
98
+ strategy="last",
99
+ token_counter=count_tokens_approximately,
100
+ allow_partial=True,
101
+ max_tokens=MAX_TOKENS,
102
+ )
103
+
104
+ return trimmed_messages[0].content if trimmed_messages else text
105
  except Exception as e:
106
  # Fallback to raw HTML if sanitization fails
107
+ messages = [HumanMessage(content=raw_html)]
108
+
109
+ if MAX_TOKENS:
110
+ # Use trim_messages to fit max tokens
111
+ trimmed_messages = trim_messages(
112
+ messages,
113
+ strategy="last",
114
+ token_counter=count_tokens_approximately,
115
+ allow_partial=True,
116
+ max_tokens=MAX_TOKENS,
117
+ )
118
+ return trimmed_messages[0].content if trimmed_messages else text
119
+
120
  except requests.RequestException as e:
121
  return f"Error fetching page: {str(e)}"
122
  except Exception as e:
uv.lock CHANGED
@@ -1289,6 +1289,19 @@ wheels = [
1289
  { url = "https://files.pythonhosted.org/packages/84/d8/e1162835d5d6eefaae341c2d1cf750ab53222a421252346905187e53b8a2/langchain_google_genai-2.1.9-py3-none-any.whl", hash = "sha256:8d3aab59706b8f8920a22bcfd63c5000ce430fe61db6ecdec262977d1a0be5b8", size = 49381, upload-time = "2025-08-04T18:51:50.51Z" },
1290
  ]
1291
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1292
  [[package]]
1293
  name = "langchain-openai"
1294
  version = "0.3.29"
@@ -1648,6 +1661,19 @@ wheels = [
1648
  { url = "https://files.pythonhosted.org/packages/67/0e/35082d13c09c02c011cf21570543d202ad929d961c02a147493cb0c2bdf5/numpy-2.2.6-cp313-cp313t-win_amd64.whl", hash = "sha256:6031dd6dfecc0cf9f668681a37648373bddd6421fff6c66ec1624eed0180ee06", size = 12771374, upload-time = "2025-05-17T21:43:35.479Z" },
1649
  ]
1650
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1651
  [[package]]
1652
  name = "openai"
1653
  version = "1.99.6"
@@ -1851,6 +1877,7 @@ dependencies = [
1851
  { name = "langchain-community" },
1852
  { name = "langchain-deepseek" },
1853
  { name = "langchain-google-genai" },
 
1854
  { name = "langchain-tavily" },
1855
  { name = "langgraph" },
1856
  { name = "matplotlib" },
@@ -1887,6 +1914,7 @@ requires-dist = [
1887
  { name = "langchain-community", specifier = ">=0.3.27" },
1888
  { name = "langchain-deepseek", specifier = ">=0.1.4" },
1889
  { name = "langchain-google-genai", specifier = ">=2.1.9" },
 
1890
  { name = "langchain-tavily", specifier = ">=0.2.11" },
1891
  { name = "langgraph", specifier = ">=0.4.8" },
1892
  { name = "matplotlib", specifier = ">=3.10.5" },
 
1289
  { url = "https://files.pythonhosted.org/packages/84/d8/e1162835d5d6eefaae341c2d1cf750ab53222a421252346905187e53b8a2/langchain_google_genai-2.1.9-py3-none-any.whl", hash = "sha256:8d3aab59706b8f8920a22bcfd63c5000ce430fe61db6ecdec262977d1a0be5b8", size = 49381, upload-time = "2025-08-04T18:51:50.51Z" },
1290
  ]
1291
 
1292
+ [[package]]
1293
+ name = "langchain-ollama"
1294
+ version = "0.3.6"
1295
+ source = { registry = "https://pypi.org/simple" }
1296
+ dependencies = [
1297
+ { name = "langchain-core" },
1298
+ { name = "ollama" },
1299
+ ]
1300
+ sdist = { url = "https://files.pythonhosted.org/packages/82/67/93429a78d6fd40e2addf27e881db37e7f0076d712ffe9759ca0d5e10910e/langchain_ollama-0.3.6.tar.gz", hash = "sha256:4270c4b30b3f3d10850cb9a1183b8c77d616195e0d9717ac745ef7f7f6cc2b6e", size = 30479, upload-time = "2025-07-22T17:26:59.605Z" }
1301
+ wheels = [
1302
+ { url = "https://files.pythonhosted.org/packages/f3/c5/1e559f5b43d62850ea2b44097afc944f38894eac00e7feef3b42f0428916/langchain_ollama-0.3.6-py3-none-any.whl", hash = "sha256:b339bd3fcf913b8d606ad426ef39e7122695532507fcd85aa96271b3f33dc3df", size = 24535, upload-time = "2025-07-22T17:26:58.556Z" },
1303
+ ]
1304
+
1305
  [[package]]
1306
  name = "langchain-openai"
1307
  version = "0.3.29"
 
1661
  { url = "https://files.pythonhosted.org/packages/67/0e/35082d13c09c02c011cf21570543d202ad929d961c02a147493cb0c2bdf5/numpy-2.2.6-cp313-cp313t-win_amd64.whl", hash = "sha256:6031dd6dfecc0cf9f668681a37648373bddd6421fff6c66ec1624eed0180ee06", size = 12771374, upload-time = "2025-05-17T21:43:35.479Z" },
1662
  ]
1663
 
1664
+ [[package]]
1665
+ name = "ollama"
1666
+ version = "0.5.3"
1667
+ source = { registry = "https://pypi.org/simple" }
1668
+ dependencies = [
1669
+ { name = "httpx" },
1670
+ { name = "pydantic" },
1671
+ ]
1672
+ sdist = { url = "https://files.pythonhosted.org/packages/91/6d/ae96027416dcc2e98c944c050c492789502d7d7c0b95a740f0bb39268632/ollama-0.5.3.tar.gz", hash = "sha256:40b6dff729df3b24e56d4042fd9d37e231cee8e528677e0d085413a1d6692394", size = 43331, upload-time = "2025-08-07T21:44:10.422Z" }
1673
+ wheels = [
1674
+ { url = "https://files.pythonhosted.org/packages/be/f6/2091e50b8b6c3e6901f6eab283d5efd66fb71c86ddb1b4d68766c3eeba0f/ollama-0.5.3-py3-none-any.whl", hash = "sha256:a8303b413d99a9043dbf77ebf11ced672396b59bec27e6d5db67c88f01b279d2", size = 13490, upload-time = "2025-08-07T21:44:09.353Z" },
1675
+ ]
1676
+
1677
  [[package]]
1678
  name = "openai"
1679
  version = "1.99.6"
 
1877
  { name = "langchain-community" },
1878
  { name = "langchain-deepseek" },
1879
  { name = "langchain-google-genai" },
1880
+ { name = "langchain-ollama" },
1881
  { name = "langchain-tavily" },
1882
  { name = "langgraph" },
1883
  { name = "matplotlib" },
 
1914
  { name = "langchain-community", specifier = ">=0.3.27" },
1915
  { name = "langchain-deepseek", specifier = ">=0.1.4" },
1916
  { name = "langchain-google-genai", specifier = ">=2.1.9" },
1917
+ { name = "langchain-ollama", specifier = ">=0.3.6" },
1918
  { name = "langchain-tavily", specifier = ">=0.2.11" },
1919
  { name = "langgraph", specifier = ">=0.4.8" },
1920
  { name = "matplotlib", specifier = ">=3.10.5" },