Spaces:
Runtime error
Runtime error
abtsousa
commited on
Commit
·
0242ef6
1
Parent(s):
ffc544c
Update configuration and enhance tool functionality
Browse files- Set API_BASE_URL, API_KEY_ENV_VAR, and MODEL_NAME in config.py
- Import ChatOllama in nodes.py and adjust model fetching logic
- Ignore type checking for process_questions in app.py
- Add langchain-ollama dependency in pyproject.toml and uv.lock
- Refactor tool imports in __init__.py and enhance message trimming in wikipedia.py
- agent/config.py +7 -6
- agent/nodes.py +29 -4
- app.py +1 -1
- pyproject.toml +1 -0
- tools/__init__.py +5 -10
- tools/wikipedia.py +29 -3
- uv.lock +28 -0
agent/config.py
CHANGED
|
@@ -1,12 +1,13 @@
|
|
| 1 |
from typing import Literal
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
|
| 7 |
-
API_BASE_URL = "https://api.openai.com/v1/"
|
| 8 |
-
MODEL_NAME = "gpt-5"
|
| 9 |
-
API_KEY_ENV_VAR = "OPENAI_API_KEY_ORACLEBOT"
|
|
|
|
| 10 |
|
| 11 |
MODEL_TEMPERATURE = 0.7
|
| 12 |
|
|
|
|
| 1 |
from typing import Literal
|
| 2 |
|
| 3 |
+
API_BASE_URL = "https://openrouter.ai/api/v1"
|
| 4 |
+
API_KEY_ENV_VAR = "OPENROUTER_API_KEY"
|
| 5 |
+
MODEL_NAME = "google/gemini-2.5-pro"
|
| 6 |
|
| 7 |
+
#API_BASE_URL = "https://api.openai.com/v1/"
|
| 8 |
+
#MODEL_NAME = "gpt-5"
|
| 9 |
+
#API_KEY_ENV_VAR = "OPENAI_API_KEY_ORACLEBOT"
|
| 10 |
+
MAX_TOKENS = None
|
| 11 |
|
| 12 |
MODEL_TEMPERATURE = 0.7
|
| 13 |
|
agent/nodes.py
CHANGED
|
@@ -9,6 +9,7 @@ from langchain_core.language_models.base import LanguageModelInput
|
|
| 9 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 10 |
from langchain_openai import ChatOpenAI
|
| 11 |
from langchain_deepseek import ChatDeepSeek
|
|
|
|
| 12 |
from pydantic import BaseModel, Field, SecretStr
|
| 13 |
from agent.prompts import get_system_prompt
|
| 14 |
from agent.state import State
|
|
@@ -17,8 +18,9 @@ from langgraph.prebuilt import ToolNode
|
|
| 17 |
import backoff
|
| 18 |
import openai
|
| 19 |
import re
|
|
|
|
| 20 |
|
| 21 |
-
from agent.config import API_BASE_URL, MODEL_NAME, API_KEY_ENV_VAR, MODEL_TEMPERATURE
|
| 22 |
|
| 23 |
from dotenv import load_dotenv
|
| 24 |
load_dotenv()
|
|
@@ -39,6 +41,16 @@ def _get_model() -> BaseChatModel:
|
|
| 39 |
|
| 40 |
api_key = os.getenv(API_KEY_ENV_VAR)
|
| 41 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
return ChatOpenAI(
|
| 43 |
api_key=SecretStr(api_key) if api_key else None,
|
| 44 |
base_url=API_BASE_URL,
|
|
@@ -68,12 +80,25 @@ def _bind_model(model: BaseChatModel) -> Runnable[LanguageModelInput, BaseMessag
|
|
| 68 |
# Call model node
|
| 69 |
@backoff.on_exception(
|
| 70 |
backoff.runtime,
|
| 71 |
-
openai.RateLimitError,
|
| 72 |
value=lambda e: float(match.group(1)) if (match := re.search(r'try again in (\d+(?:\.\d+)?)s', str(e))) else 10.0,
|
| 73 |
-
max_tries=
|
| 74 |
)
|
|
|
|
| 75 |
def call_model(state: State, config) -> dict[str, list[BaseMessage]]:
|
| 76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
app_name = config.get('configurable', {}).get("app_name", "OracleBot")
|
| 78 |
|
| 79 |
# Add system prompt if not already present
|
|
|
|
| 9 |
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 10 |
from langchain_openai import ChatOpenAI
|
| 11 |
from langchain_deepseek import ChatDeepSeek
|
| 12 |
+
from langchain_ollama import ChatOllama
|
| 13 |
from pydantic import BaseModel, Field, SecretStr
|
| 14 |
from agent.prompts import get_system_prompt
|
| 15 |
from agent.state import State
|
|
|
|
| 18 |
import backoff
|
| 19 |
import openai
|
| 20 |
import re
|
| 21 |
+
from langchain_core.messages.utils import trim_messages, count_tokens_approximately
|
| 22 |
|
| 23 |
+
from agent.config import API_BASE_URL, MAX_TOKENS, MODEL_NAME, API_KEY_ENV_VAR, MODEL_TEMPERATURE
|
| 24 |
|
| 25 |
from dotenv import load_dotenv
|
| 26 |
load_dotenv()
|
|
|
|
| 41 |
|
| 42 |
api_key = os.getenv(API_KEY_ENV_VAR)
|
| 43 |
|
| 44 |
+
# return ChatOllama(
|
| 45 |
+
# model=MODEL_NAME,
|
| 46 |
+
# temperature=MODEL_TEMPERATURE if MODEL_TEMPERATURE else 0.0,
|
| 47 |
+
# metadata={
|
| 48 |
+
# "reasoning": {
|
| 49 |
+
# "effort": "high" # Use high reasoning effort
|
| 50 |
+
# }
|
| 51 |
+
# }
|
| 52 |
+
# )
|
| 53 |
+
|
| 54 |
return ChatOpenAI(
|
| 55 |
api_key=SecretStr(api_key) if api_key else None,
|
| 56 |
base_url=API_BASE_URL,
|
|
|
|
| 80 |
# Call model node
|
| 81 |
@backoff.on_exception(
|
| 82 |
backoff.runtime,
|
| 83 |
+
(openai.RateLimitError, openai.InternalServerError),
|
| 84 |
value=lambda e: float(match.group(1)) if (match := re.search(r'try again in (\d+(?:\.\d+)?)s', str(e))) else 10.0,
|
| 85 |
+
max_tries=200,
|
| 86 |
)
|
| 87 |
+
|
| 88 |
def call_model(state: State, config) -> dict[str, list[BaseMessage]]:
|
| 89 |
+
if MAX_TOKENS:
|
| 90 |
+
messages = trim_messages(
|
| 91 |
+
state["messages"],
|
| 92 |
+
strategy="last",
|
| 93 |
+
token_counter=count_tokens_approximately,
|
| 94 |
+
allow_partial=True,
|
| 95 |
+
max_tokens=MAX_TOKENS,
|
| 96 |
+
start_on="human",
|
| 97 |
+
end_on=("human", "tool"),
|
| 98 |
+
)
|
| 99 |
+
else:
|
| 100 |
+
messages = state["messages"]
|
| 101 |
+
|
| 102 |
app_name = config.get('configurable', {}).get("app_name", "OracleBot")
|
| 103 |
|
| 104 |
# Add system prompt if not already present
|
app.py
CHANGED
|
@@ -119,7 +119,7 @@ async def run_and_submit_all( profile: gr.OAuthProfile | None):
|
|
| 119 |
# Use the cache directory for this session
|
| 120 |
working_dir = CACHE_DIR
|
| 121 |
|
| 122 |
-
results_log, answers_payload = await process_questions(agent, questions_data, working_dir)
|
| 123 |
|
| 124 |
# Remove everything before "FINAL ANSWER: " in submitted answers
|
| 125 |
for answer in answers_payload:
|
|
|
|
| 119 |
# Use the cache directory for this session
|
| 120 |
working_dir = CACHE_DIR
|
| 121 |
|
| 122 |
+
results_log, answers_payload = await process_questions(agent, questions_data, working_dir) # type: ignore
|
| 123 |
|
| 124 |
# Remove everything before "FINAL ANSWER: " in submitted answers
|
| 125 |
for answer in answers_payload:
|
pyproject.toml
CHANGED
|
@@ -15,6 +15,7 @@ dependencies = [
|
|
| 15 |
"langchain-community>=0.3.27",
|
| 16 |
"langchain-deepseek>=0.1.4",
|
| 17 |
"langchain-google-genai>=2.1.9",
|
|
|
|
| 18 |
"langchain-tavily>=0.2.11",
|
| 19 |
"langchain[google-genai,googlegenai,openai]>=0.3.26",
|
| 20 |
"langgraph>=0.4.8",
|
|
|
|
| 15 |
"langchain-community>=0.3.27",
|
| 16 |
"langchain-deepseek>=0.1.4",
|
| 17 |
"langchain-google-genai>=2.1.9",
|
| 18 |
+
"langchain-ollama>=0.3.6",
|
| 19 |
"langchain-tavily>=0.2.11",
|
| 20 |
"langchain[google-genai,googlegenai,openai]>=0.3.26",
|
| 21 |
"langgraph>=0.4.8",
|
tools/__init__.py
CHANGED
|
@@ -1,8 +1,6 @@
|
|
| 1 |
from .wikipedia import wiki_fetch_article, wiki_parse_html
|
| 2 |
from .search import web_search
|
| 3 |
from .code_interpreter import execute_code_multilang
|
| 4 |
-
from .files import *
|
| 5 |
-
from .calculator import *
|
| 6 |
from langchain_core.tools import BaseTool
|
| 7 |
from . import calculator, files
|
| 8 |
|
|
@@ -17,17 +15,14 @@ def get_all_tools() -> list[BaseTool]:
|
|
| 17 |
wiki_fetch_article,
|
| 18 |
wiki_parse_html,
|
| 19 |
web_search,
|
| 20 |
-
execute_code_multilang,
|
| 21 |
]
|
| 22 |
|
| 23 |
-
# Automatically add all
|
| 24 |
-
for name in calculator.__all__
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
# Automatically add all file management functions
|
| 28 |
-
for name in files.__all__:
|
| 29 |
-
tools.append(getattr(files, name))
|
| 30 |
|
|
|
|
| 31 |
return tools
|
| 32 |
|
| 33 |
def list_tools() -> str:
|
|
|
|
| 1 |
from .wikipedia import wiki_fetch_article, wiki_parse_html
|
| 2 |
from .search import web_search
|
| 3 |
from .code_interpreter import execute_code_multilang
|
|
|
|
|
|
|
| 4 |
from langchain_core.tools import BaseTool
|
| 5 |
from . import calculator, files
|
| 6 |
|
|
|
|
| 15 |
wiki_fetch_article,
|
| 16 |
wiki_parse_html,
|
| 17 |
web_search,
|
| 18 |
+
#execute_code_multilang, # Disabled for the repo, enable at your own risk
|
| 19 |
]
|
| 20 |
|
| 21 |
+
# Automatically add all functions for the remaining tools
|
| 22 |
+
tools.extend([getattr(calculator, name) for name in calculator.__all__])
|
| 23 |
+
tools.extend([getattr(files, name) for name in files.__all__])
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
+
#return [] # Disable all tools
|
| 26 |
return tools
|
| 27 |
|
| 28 |
def list_tools() -> str:
|
tools/wikipedia.py
CHANGED
|
@@ -2,6 +2,9 @@ from langchain_core.tools import tool
|
|
| 2 |
import wikipediaapi
|
| 3 |
import requests
|
| 4 |
from bs4 import BeautifulSoup
|
|
|
|
|
|
|
|
|
|
| 5 |
|
| 6 |
@tool
|
| 7 |
def wiki_fetch_article(article_title: str) -> str:
|
|
@@ -86,11 +89,34 @@ def wiki_parse_html(page_title: str, section_id: int | None = None) -> str:
|
|
| 86 |
|
| 87 |
# Optional: collapse excessive whitespace
|
| 88 |
text = str(soup)
|
| 89 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
except Exception as e:
|
| 91 |
# Fallback to raw HTML if sanitization fails
|
| 92 |
-
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
except requests.RequestException as e:
|
| 95 |
return f"Error fetching page: {str(e)}"
|
| 96 |
except Exception as e:
|
|
|
|
| 2 |
import wikipediaapi
|
| 3 |
import requests
|
| 4 |
from bs4 import BeautifulSoup
|
| 5 |
+
from langchain_core.messages.utils import count_tokens_approximately, trim_messages
|
| 6 |
+
from langchain_core.messages import HumanMessage
|
| 7 |
+
from agent.config import MAX_TOKENS
|
| 8 |
|
| 9 |
@tool
|
| 10 |
def wiki_fetch_article(article_title: str) -> str:
|
|
|
|
| 89 |
|
| 90 |
# Optional: collapse excessive whitespace
|
| 91 |
text = str(soup)
|
| 92 |
+
|
| 93 |
+
if MAX_TOKENS:
|
| 94 |
+
# Use trim_messages to fit max tokens
|
| 95 |
+
messages = [HumanMessage(content=text)]
|
| 96 |
+
trimmed_messages = trim_messages(
|
| 97 |
+
messages,
|
| 98 |
+
strategy="last",
|
| 99 |
+
token_counter=count_tokens_approximately,
|
| 100 |
+
allow_partial=True,
|
| 101 |
+
max_tokens=MAX_TOKENS,
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
return trimmed_messages[0].content if trimmed_messages else text
|
| 105 |
except Exception as e:
|
| 106 |
# Fallback to raw HTML if sanitization fails
|
| 107 |
+
messages = [HumanMessage(content=raw_html)]
|
| 108 |
+
|
| 109 |
+
if MAX_TOKENS:
|
| 110 |
+
# Use trim_messages to fit max tokens
|
| 111 |
+
trimmed_messages = trim_messages(
|
| 112 |
+
messages,
|
| 113 |
+
strategy="last",
|
| 114 |
+
token_counter=count_tokens_approximately,
|
| 115 |
+
allow_partial=True,
|
| 116 |
+
max_tokens=MAX_TOKENS,
|
| 117 |
+
)
|
| 118 |
+
return trimmed_messages[0].content if trimmed_messages else text
|
| 119 |
+
|
| 120 |
except requests.RequestException as e:
|
| 121 |
return f"Error fetching page: {str(e)}"
|
| 122 |
except Exception as e:
|
uv.lock
CHANGED
|
@@ -1289,6 +1289,19 @@ wheels = [
|
|
| 1289 |
{ url = "https://files.pythonhosted.org/packages/84/d8/e1162835d5d6eefaae341c2d1cf750ab53222a421252346905187e53b8a2/langchain_google_genai-2.1.9-py3-none-any.whl", hash = "sha256:8d3aab59706b8f8920a22bcfd63c5000ce430fe61db6ecdec262977d1a0be5b8", size = 49381, upload-time = "2025-08-04T18:51:50.51Z" },
|
| 1290 |
]
|
| 1291 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1292 |
[[package]]
|
| 1293 |
name = "langchain-openai"
|
| 1294 |
version = "0.3.29"
|
|
@@ -1648,6 +1661,19 @@ wheels = [
|
|
| 1648 |
{ url = "https://files.pythonhosted.org/packages/67/0e/35082d13c09c02c011cf21570543d202ad929d961c02a147493cb0c2bdf5/numpy-2.2.6-cp313-cp313t-win_amd64.whl", hash = "sha256:6031dd6dfecc0cf9f668681a37648373bddd6421fff6c66ec1624eed0180ee06", size = 12771374, upload-time = "2025-05-17T21:43:35.479Z" },
|
| 1649 |
]
|
| 1650 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1651 |
[[package]]
|
| 1652 |
name = "openai"
|
| 1653 |
version = "1.99.6"
|
|
@@ -1851,6 +1877,7 @@ dependencies = [
|
|
| 1851 |
{ name = "langchain-community" },
|
| 1852 |
{ name = "langchain-deepseek" },
|
| 1853 |
{ name = "langchain-google-genai" },
|
|
|
|
| 1854 |
{ name = "langchain-tavily" },
|
| 1855 |
{ name = "langgraph" },
|
| 1856 |
{ name = "matplotlib" },
|
|
@@ -1887,6 +1914,7 @@ requires-dist = [
|
|
| 1887 |
{ name = "langchain-community", specifier = ">=0.3.27" },
|
| 1888 |
{ name = "langchain-deepseek", specifier = ">=0.1.4" },
|
| 1889 |
{ name = "langchain-google-genai", specifier = ">=2.1.9" },
|
|
|
|
| 1890 |
{ name = "langchain-tavily", specifier = ">=0.2.11" },
|
| 1891 |
{ name = "langgraph", specifier = ">=0.4.8" },
|
| 1892 |
{ name = "matplotlib", specifier = ">=3.10.5" },
|
|
|
|
| 1289 |
{ url = "https://files.pythonhosted.org/packages/84/d8/e1162835d5d6eefaae341c2d1cf750ab53222a421252346905187e53b8a2/langchain_google_genai-2.1.9-py3-none-any.whl", hash = "sha256:8d3aab59706b8f8920a22bcfd63c5000ce430fe61db6ecdec262977d1a0be5b8", size = 49381, upload-time = "2025-08-04T18:51:50.51Z" },
|
| 1290 |
]
|
| 1291 |
|
| 1292 |
+
[[package]]
|
| 1293 |
+
name = "langchain-ollama"
|
| 1294 |
+
version = "0.3.6"
|
| 1295 |
+
source = { registry = "https://pypi.org/simple" }
|
| 1296 |
+
dependencies = [
|
| 1297 |
+
{ name = "langchain-core" },
|
| 1298 |
+
{ name = "ollama" },
|
| 1299 |
+
]
|
| 1300 |
+
sdist = { url = "https://files.pythonhosted.org/packages/82/67/93429a78d6fd40e2addf27e881db37e7f0076d712ffe9759ca0d5e10910e/langchain_ollama-0.3.6.tar.gz", hash = "sha256:4270c4b30b3f3d10850cb9a1183b8c77d616195e0d9717ac745ef7f7f6cc2b6e", size = 30479, upload-time = "2025-07-22T17:26:59.605Z" }
|
| 1301 |
+
wheels = [
|
| 1302 |
+
{ url = "https://files.pythonhosted.org/packages/f3/c5/1e559f5b43d62850ea2b44097afc944f38894eac00e7feef3b42f0428916/langchain_ollama-0.3.6-py3-none-any.whl", hash = "sha256:b339bd3fcf913b8d606ad426ef39e7122695532507fcd85aa96271b3f33dc3df", size = 24535, upload-time = "2025-07-22T17:26:58.556Z" },
|
| 1303 |
+
]
|
| 1304 |
+
|
| 1305 |
[[package]]
|
| 1306 |
name = "langchain-openai"
|
| 1307 |
version = "0.3.29"
|
|
|
|
| 1661 |
{ url = "https://files.pythonhosted.org/packages/67/0e/35082d13c09c02c011cf21570543d202ad929d961c02a147493cb0c2bdf5/numpy-2.2.6-cp313-cp313t-win_amd64.whl", hash = "sha256:6031dd6dfecc0cf9f668681a37648373bddd6421fff6c66ec1624eed0180ee06", size = 12771374, upload-time = "2025-05-17T21:43:35.479Z" },
|
| 1662 |
]
|
| 1663 |
|
| 1664 |
+
[[package]]
|
| 1665 |
+
name = "ollama"
|
| 1666 |
+
version = "0.5.3"
|
| 1667 |
+
source = { registry = "https://pypi.org/simple" }
|
| 1668 |
+
dependencies = [
|
| 1669 |
+
{ name = "httpx" },
|
| 1670 |
+
{ name = "pydantic" },
|
| 1671 |
+
]
|
| 1672 |
+
sdist = { url = "https://files.pythonhosted.org/packages/91/6d/ae96027416dcc2e98c944c050c492789502d7d7c0b95a740f0bb39268632/ollama-0.5.3.tar.gz", hash = "sha256:40b6dff729df3b24e56d4042fd9d37e231cee8e528677e0d085413a1d6692394", size = 43331, upload-time = "2025-08-07T21:44:10.422Z" }
|
| 1673 |
+
wheels = [
|
| 1674 |
+
{ url = "https://files.pythonhosted.org/packages/be/f6/2091e50b8b6c3e6901f6eab283d5efd66fb71c86ddb1b4d68766c3eeba0f/ollama-0.5.3-py3-none-any.whl", hash = "sha256:a8303b413d99a9043dbf77ebf11ced672396b59bec27e6d5db67c88f01b279d2", size = 13490, upload-time = "2025-08-07T21:44:09.353Z" },
|
| 1675 |
+
]
|
| 1676 |
+
|
| 1677 |
[[package]]
|
| 1678 |
name = "openai"
|
| 1679 |
version = "1.99.6"
|
|
|
|
| 1877 |
{ name = "langchain-community" },
|
| 1878 |
{ name = "langchain-deepseek" },
|
| 1879 |
{ name = "langchain-google-genai" },
|
| 1880 |
+
{ name = "langchain-ollama" },
|
| 1881 |
{ name = "langchain-tavily" },
|
| 1882 |
{ name = "langgraph" },
|
| 1883 |
{ name = "matplotlib" },
|
|
|
|
| 1914 |
{ name = "langchain-community", specifier = ">=0.3.27" },
|
| 1915 |
{ name = "langchain-deepseek", specifier = ">=0.1.4" },
|
| 1916 |
{ name = "langchain-google-genai", specifier = ">=2.1.9" },
|
| 1917 |
+
{ name = "langchain-ollama", specifier = ">=0.3.6" },
|
| 1918 |
{ name = "langchain-tavily", specifier = ">=0.2.11" },
|
| 1919 |
{ name = "langgraph", specifier = ">=0.4.8" },
|
| 1920 |
{ name = "matplotlib", specifier = ">=3.10.5" },
|