Spaces:
Sleeping
Sleeping
DVampire
committed on
Commit
·
bf5c0e0
1
Parent(s):
d3e5344
update website
Browse files- configs/paper_agent.py +4 -0
- src/config/__init__.py +3 -0
- src/config/config.py +86 -0
- src/database/__init__.py +5 -0
- src/database/db.py +143 -0
- src/logger/__init__.py +10 -0
- src/logger/logger.py +229 -0
- src/utils/__init__.py +8 -0
- src/utils/hf_utils.py +0 -0
- src/utils/path_utils.py +12 -0
- src/utils/singleton.py +25 -0
configs/paper_agent.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Experiment-layout settings for the paper agent.
# NOTE(review): config.py's process_general() later re-resolves exp_path and
# log_path into absolute paths under the project root — confirm against loader.

workdir = "workdir"  # root directory for all experiment outputs
tag = "paper_agent"  # run tag; namespaces this experiment under workdir
exp_path = f"{workdir}/{tag}"  # per-run output directory
log_path = "agent.log"  # log file name (joined onto exp_path at config load)
|
src/config/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .config import config
|
| 2 |
+
|
| 3 |
+
__all__ = ['config']
|
src/config/config.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from mmengine import Config as MMConfig
|
| 3 |
+
from argparse import Namespace
|
| 4 |
+
|
| 5 |
+
from dotenv import load_dotenv
|
| 6 |
+
load_dotenv(verbose=True)
|
| 7 |
+
|
| 8 |
+
from finworld.utils import assemble_project_path, get_tag_name, Singleton, set_seed
|
| 9 |
+
|
| 10 |
+
def check_level(level: str) -> bool:
    """Return True when *level* is one of the supported bar intervals."""
    return level in ('1day', '1min', '5min', '15min', '30min', '1hour', '4hour')
|
| 18 |
+
|
| 19 |
+
def process_general(config: MMConfig) -> MMConfig:
    """Resolve experiment paths in *config* and apply global setup.

    Creates the experiment directory (<workdir>/<tag>), rewrites the log /
    checkpoint / plot / tracker paths to live under it, and seeds the RNGs
    when a seed is configured.

    Args:
        config: The mmengine config object; mutated in place.

    Returns:
        The same config object with paths resolved to absolute locations.
    """
    # Experiment root: <project_root>/<workdir>/<tag>
    config.exp_path = assemble_project_path(os.path.join(config.workdir, config.tag))
    os.makedirs(config.exp_path, exist_ok=True)

    # Log file lives inside the experiment directory.
    config.log_path = os.path.join(config.exp_path, getattr(config, 'log_path', 'finworld.log'))

    if "checkpoint_path" in config:
        # NOTE: the getattr default is unreachable here — the key is known to exist.
        config.checkpoint_path = os.path.join(config.exp_path, getattr(config, 'checkpoint_path', 'checkpoint'))
        os.makedirs(config.checkpoint_path, exist_ok=True)

    if "plot_path" in config:
        config.plot_path = os.path.join(config.exp_path, getattr(config, 'plot_path', 'plot'))
        os.makedirs(config.plot_path, exist_ok=True)

    if "tracker" in config:
        # Re-anchor every tracker's logging_dir under the experiment directory.
        for key, value in config.tracker.items():
            config.tracker[key]['logging_dir'] = os.path.join(config.exp_path, value['logging_dir'])

    if "seed" in config:
        set_seed(config.seed)

    return config
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class Config(MMConfig, metaclass=Singleton):
    """Process-wide configuration singleton layered on mmengine's Config.

    Call :meth:`init_config` once at startup to load a config file, overlay
    CLI arguments, and resolve experiment paths; the merged state is then
    exposed as attributes on the singleton instance.
    """

    def __init__(self):
        super(Config, self).__init__()

    def init_config(self, config_path: str, args: Namespace) -> None:
        """Load *config_path*, merge CLI *args* over it, and finalize paths.

        Args:
            config_path: Path to the mmengine config file (relative paths are
                resolved against the project root).
            args: Parsed CLI namespace; every non-None attribute except
                'config' and 'cfg_options' overrides the file value.
        """
        # Initialize the general configuration
        mmconfig = MMConfig.fromfile(filename=assemble_project_path(config_path))
        # Start from --cfg-options (if given), then fold in every other CLI flag.
        if 'cfg_options' not in args or args.cfg_options is None:
            cfg_options = dict()
        else:
            cfg_options = args.cfg_options
        for item in args.__dict__:
            if item not in ['config', 'cfg_options'] and args.__dict__[item] is not None:
                cfg_options[item] = args.__dict__[item]
        mmconfig.merge_from_dict(cfg_options)

        # Derive the run tag from whatever metadata the config provides.
        tag = get_tag_name(
            tag=getattr(mmconfig, 'tag', None),
            assets_name=getattr(mmconfig, 'assets_name', None),
            source=getattr(mmconfig, 'source', None),
            data_type=getattr(mmconfig, 'data_type', None),
            level=getattr(mmconfig, 'level', None),
        )
        mmconfig.tag = tag

        # Process general configuration (exp/log/checkpoint/plot paths, seed)
        mmconfig = process_general(mmconfig)

        # Initialize the price downloader configuration
        if 'downloader' in mmconfig:
            if "assets_path" in mmconfig.downloader:
                mmconfig.downloader.assets_path = assemble_project_path(mmconfig.downloader.assets_path)
            # NOTE(review): assert is stripped under `python -O`; consider raising instead.
            assert check_level(mmconfig.downloader.level), f"Invalid level: {mmconfig.downloader.level}. Valid levels are: ['1day', '1min', '5min', '15min', '30min', '1hour', '4hour']"

        if 'processor' in mmconfig:
            if "assets_path" in mmconfig.processor:
                mmconfig.processor.assets_path = assemble_project_path(mmconfig.processor.assets_path)
            # Prefix the HF repo id with the account name from the environment.
            mmconfig.processor.repo_id = f"{os.getenv('HF_REPO_NAME')}/{mmconfig.processor.repo_id}"
            mmconfig.processor.repo_type = mmconfig.processor.repo_type if 'repo_type' in mmconfig.processor else 'dataset'

        # Expose the merged config on the singleton instance.
        self.__dict__.update(mmconfig.__dict__)


# Module-level singleton; populated later via config.init_config(...).
config = Config()
|
src/database/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Database management for paper caching
|
| 2 |
+
|
| 3 |
+
from .db import PapersDatabase, db
|
| 4 |
+
|
| 5 |
+
__all__ = ['PapersDatabase', 'db']
|
src/database/db.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import sqlite3
|
| 4 |
+
from datetime import date, datetime, timedelta
|
| 5 |
+
from typing import Any, Dict, List, Optional
|
| 6 |
+
from contextlib import contextmanager
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class PapersDatabase():
    """SQLite-backed cache of daily-papers pages.

    Stores the raw HTML and the parsed paper cards (as JSON) per date, plus a
    single-row ``latest_date`` table tracking the most recent available date.
    A fresh connection is opened per operation via :meth:`get_connection`.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Set later by init_db(); connections are opened lazily from it.
        self.db_path = None

    def init_db(self, config):
        """Initialize the database with required tables.

        Args:
            config: Any object exposing a ``db_path`` attribute with the
                SQLite file location.
        """
        self.db_path = config.db_path

        with self.get_connection() as conn:
            cursor = conn.cursor()

            # Create papers cache table
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS papers_cache (
                    date_str TEXT PRIMARY KEY,
                    html_content TEXT NOT NULL,
                    parsed_cards TEXT NOT NULL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            ''')

            # Create latest_date table to track the most recent available date
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS latest_date (
                    id INTEGER PRIMARY KEY CHECK (id = 1),
                    date_str TEXT NOT NULL,
                    updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
                )
            ''')

            # Insert default latest_date record if it doesn't exist
            cursor.execute('''
                INSERT OR IGNORE INTO latest_date (id, date_str)
                VALUES (1, ?)
            ''', (date.today().isoformat(),))

            conn.commit()

    @contextmanager
    def get_connection(self):
        """Context manager yielding a new connection; always closes it."""
        conn = sqlite3.connect(self.db_path)
        conn.row_factory = sqlite3.Row  # Enable dict-like access
        try:
            yield conn
        finally:
            conn.close()

    def get_cached_papers(self, date_str: str) -> Optional[Dict[str, Any]]:
        """Return cached cards for *date_str*, or None on a cache miss.

        Returns:
            ``{'cards': list, 'cached_at': str}`` or None.
        """
        # Bug fix: get_connection() takes no arguments; passing self.db_path
        # raised TypeError on every lookup.
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute('''
                SELECT parsed_cards, created_at
                FROM papers_cache
                WHERE date_str = ?
            ''', (date_str,))

            row = cursor.fetchone()
            if row:
                return {
                    'cards': json.loads(row['parsed_cards']),
                    'cached_at': row['created_at']
                }
            return None

    def cache_papers(self, date_str: str, html_content: str, parsed_cards: List[Dict[str, Any]]):
        """Insert or replace the cache entry for *date_str*."""
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute('''
                INSERT OR REPLACE INTO papers_cache
                (date_str, html_content, parsed_cards, updated_at)
                VALUES (?, ?, ?, CURRENT_TIMESTAMP)
            ''', (date_str, html_content, json.dumps(parsed_cards)))
            conn.commit()

    def get_latest_cached_date(self) -> Optional[str]:
        """Return the latest known date string, or None if unset."""
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute('SELECT date_str FROM latest_date WHERE id = 1')
            row = cursor.fetchone()
            return row['date_str'] if row else None

    def update_latest_date(self, date_str: str):
        """Update the latest available date."""
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute('''
                UPDATE latest_date
                SET date_str = ?, updated_at = CURRENT_TIMESTAMP
                WHERE id = 1
            ''', (date_str,))
            conn.commit()

    def is_cache_fresh(self, date_str: str, max_age_hours: int = 24) -> bool:
        """Return True when the entry for *date_str* is younger than *max_age_hours*."""
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute('''
                SELECT updated_at
                FROM papers_cache
                WHERE date_str = ?
            ''', (date_str,))

            row = cursor.fetchone()
            if not row:
                return False

            cached_time = datetime.fromisoformat(row['updated_at'].replace('Z', '+00:00'))
            # SQLite's CURRENT_TIMESTAMP is naive UTC, so compare against UTC
            # (the previous code compared against local time).
            if cached_time.tzinfo is None:
                now = datetime.utcnow()
            else:
                now = datetime.now(cached_time.tzinfo)
            age = now - cached_time
            return age.total_seconds() < max_age_hours * 3600

    def cleanup_old_cache(self, days_to_keep: int = 7):
        """Delete cache entries older than *days_to_keep* days."""
        # Match SQLite's CURRENT_TIMESTAMP format exactly: naive UTC with a
        # space separator. isoformat() used 'T' and local time, which made the
        # lexicographic comparison below wrong by up to a day.
        cutoff_date = datetime.utcnow().replace(microsecond=0) - timedelta(days=days_to_keep)
        cutoff_str = cutoff_date.strftime('%Y-%m-%d %H:%M:%S')
        with self.get_connection() as conn:
            cursor = conn.cursor()
            cursor.execute('''
                DELETE FROM papers_cache
                WHERE updated_at < ?
            ''', (cutoff_str,))
            conn.commit()

    def __str__(self):
        return f"PapersDatabase(db_path={self.db_path})"

    def __repr__(self):
        return self.__str__()


# Module-level instance; configure via db.init_db(config) before use.
db = PapersDatabase()
|
src/logger/__init__.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .logger import logger, LogLevel, AgentLogger, YELLOW_HEX
|
| 2 |
+
from .monitor import Monitor, Timing, TokenUsage
|
| 3 |
+
|
| 4 |
+
__all__ = ["logger",
|
| 5 |
+
"LogLevel",
|
| 6 |
+
"AgentLogger",
|
| 7 |
+
"Monitor",
|
| 8 |
+
"YELLOW_HEX",
|
| 9 |
+
"Timing",
|
| 10 |
+
"TokenUsage"]
|
src/logger/logger.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import json
|
| 3 |
+
from enum import IntEnum
|
| 4 |
+
|
| 5 |
+
from rich import box
|
| 6 |
+
from rich.console import Console, Group
|
| 7 |
+
from rich.panel import Panel
|
| 8 |
+
from rich.rule import Rule
|
| 9 |
+
from rich.syntax import Syntax
|
| 10 |
+
from rich.table import Table
|
| 11 |
+
from rich.tree import Tree
|
| 12 |
+
|
| 13 |
+
from src.utils import (
|
| 14 |
+
escape_code_brackets,
|
| 15 |
+
Singleton
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
YELLOW_HEX = "#d4b702"
|
| 19 |
+
|
| 20 |
+
class LogLevel(IntEnum):
    """Verbosity levels for AgentLogger output (higher value = more verbose)."""
    OFF = -1  # No output
    ERROR = 0  # Only errors
    INFO = 1  # Normal output (default)
    DEBUG = 2  # Detailed output
|
| 25 |
+
|
| 26 |
+
class AgentLogger(logging.Logger, metaclass=Singleton):
    """Singleton logger mirroring output to the console and a log file.

    Plain messages travel through the stdlib ``logging`` machinery (console +
    file handlers); rich renderables (Rule, Panel, Group, Tree, Table, Syntax)
    are printed via ``rich`` consoles to both the terminal and the log file.
    Call :meth:`init_logger` once before use to attach the handlers.
    """

    def __init__(self, name="logger", level=logging.INFO):
        # Initialize the parent logging.Logger
        super().__init__(name, level)

        # Formatter for plain-text records (ANSI-green prefix for terminals).
        self.formatter = logging.Formatter(
            fmt="\033[92m%(asctime)s - %(name)s:%(levelname)s\033[0m: %(filename)s:%(lineno)s - %(message)s",
            datefmt="%H:%M:%S",
        )

    def init_logger(self, log_path: str, level=logging.INFO):
        """
        Initialize the logger with a file path and optional main process check.

        Args:
            log_path (str): The log file path.
            level (int, optional): The logging level. Defaults to logging.INFO.
        """
        # Add a console handler for logging to the console
        console_handler = logging.StreamHandler()
        console_handler.setLevel(level)
        console_handler.setFormatter(self.formatter)
        self.addHandler(console_handler)

        # Add a file handler for logging to the file ('a' mode appends)
        file_handler = logging.FileHandler(log_path, mode="a")
        file_handler.setLevel(level)
        file_handler.setFormatter(self.formatter)
        self.addHandler(file_handler)

        # Rich consoles for renderables. NOTE(review): the file handle opened
        # here is intentionally long-lived and never explicitly closed.
        self.console = Console(width=100)
        self.file_console = Console(file=open(log_path, "a"), width=100)

        # Prevent duplicate logs from propagating to the root logger
        self.propagate = False

    def log(self, *args, level: int | str | LogLevel = LogLevel.INFO, **kwargs) -> None:
        """Logs a message to the console.

        Args:
            level (LogLevel, optional): Defaults to LogLevel.INFO.
        """
        if isinstance(level, str):
            level = LogLevel[level.upper()]
        # NOTE(review): this compares a LogLevel (-1..2) against the stdlib
        # logging level stored on the logger (e.g. 20), so the check passes for
        # every LogLevel under default settings — confirm intended semantics.
        if level <= self.level:
            self.info(*args, **kwargs)

    def info(self, msg, *args, **kwargs):
        """
        Overridden info method with stacklevel adjustment for correct log location.

        Rich renderables bypass logging and are printed on both rich consoles.
        """
        if isinstance(msg, (Rule, Panel, Group, Tree, Table, Syntax)):
            self.console.print(msg)
            self.file_console.print(msg)
        else:
            # Adjust stack level so the record points at the actual caller.
            kwargs.setdefault("stacklevel", 2)
            # Drop rich-only kwargs that logging.Logger.info() would reject.
            if "style" in kwargs:
                kwargs.pop("style")
            if "level" in kwargs:
                kwargs.pop("level")
            super().info(msg, *args, **kwargs)

    def warning(self, msg, *args, **kwargs):
        kwargs.setdefault("stacklevel", 2)
        super().warning(msg, *args, **kwargs)

    def error(self, msg, *args, **kwargs):
        kwargs.setdefault("stacklevel", 2)
        super().error(msg, *args, **kwargs)

    def critical(self, msg, *args, **kwargs):
        kwargs.setdefault("stacklevel", 2)
        super().critical(msg, *args, **kwargs)

    def debug(self, msg, *args, **kwargs):
        kwargs.setdefault("stacklevel", 2)
        super().debug(msg, *args, **kwargs)

    def log_error(self, error_message: str) -> None:
        """Log *error_message* in bold red with rich markup brackets escaped."""
        self.info(escape_code_brackets(error_message), style="bold red", level=LogLevel.ERROR)

    def log_markdown(self, content: str, title: str | None = None, level=LogLevel.INFO, style=YELLOW_HEX) -> None:
        """Render *content* as markdown, optionally beneath a titled rule."""
        markdown_content = Syntax(
            content,
            lexer="markdown",
            theme="github-dark",
            word_wrap=True,
        )
        if title:
            self.info(
                Group(
                    Rule(
                        "[bold italic]" + title,
                        align="left",
                        style=style,
                    ),
                    markdown_content,
                ),
                level=level,
            )
        else:
            self.info(markdown_content, level=level)

    def log_code(self, title: str, content: str, level: int = LogLevel.INFO) -> None:
        """Render *content* as syntax-highlighted Python inside a titled panel."""
        self.info(
            Panel(
                Syntax(
                    content,
                    lexer="python",
                    theme="monokai",
                    word_wrap=True,
                ),
                title="[bold]" + title,
                title_align="left",
                box=box.HORIZONTALS,
            ),
            level=level,
        )

    def log_rule(self, title: str, level: int = LogLevel.INFO) -> None:
        """Render a horizontal rule captioned with *title*."""
        self.info(
            Rule(
                "[bold]" + title,
                characters="━",
                style=YELLOW_HEX,
            ),
            # Bug fix: the level parameter was previously ignored (hard-coded
            # to LogLevel.INFO); forward it instead.
            level=level,
        )

    def log_task(self, content: str, subtitle: str, title: str | None = None, level: LogLevel = LogLevel.INFO) -> None:
        """Render a "New run" banner panel describing the task *content*."""
        self.info(
            Panel(
                f"\n[bold]{escape_code_brackets(content)}\n",
                title="[bold]New run" + (f" - {title}" if title else ""),
                subtitle=subtitle,
                border_style=YELLOW_HEX,
                subtitle_align="left",
            ),
            level=level,
        )

    def log_messages(self, messages: list[dict], level: LogLevel = LogLevel.DEBUG) -> None:
        """Pretty-print a list of chat messages as indented JSON."""
        messages_as_string = "\n".join([json.dumps(dict(message), indent=4, ensure_ascii=False) for message in messages])
        self.info(
            Syntax(
                messages_as_string,
                lexer="markdown",
                theme="github-dark",
                word_wrap=True,
            ),
            level=level,
        )

    def visualize_agent_tree(self, agent):
        """Print a rich tree of *agent*, its tools, and its managed agents.

        NOTE(review): the emoji / rule characters below were mojibake in the
        source diff; restored to the most plausible originals — confirm.
        """
        def create_tools_section(tools_dict):
            # One table row per tool: name, description, formatted arguments.
            table = Table(show_header=True, header_style="bold")
            table.add_column("Name", style="#1E90FF")
            table.add_column("Description")
            table.add_column("Arguments")

            for name, tool in tools_dict.items():
                args = [
                    f"{arg_name} (`{info.get('type', 'Any')}`{', optional' if info.get('optional') else ''}): {info.get('description', '')}"
                    for arg_name, info in getattr(tool, "inputs", {}).items()
                ]
                table.add_row(name, getattr(tool, "description", str(tool)), "\n".join(args))

            return Group("🛠️ [italic #1E90FF]Tools:[/italic #1E90FF]", table)

        def get_agent_headline(agent, name: str | None = None):
            name_headline = f"{name} | " if name else ""
            return f"[bold {YELLOW_HEX}]{name_headline}{agent.__class__.__name__} | {agent.model.model_id}"

        def build_agent_tree(parent_tree, agent_obj):
            """Recursively builds the agent tree."""
            parent_tree.add(create_tools_section(agent_obj.tools))

            if agent_obj.managed_agents:
                agents_branch = parent_tree.add("🤖 [italic #1E90FF]Managed agents:")
                for name, managed_agent in agent_obj.managed_agents.items():
                    agent_tree = agents_branch.add(get_agent_headline(managed_agent, name))
                    if managed_agent.__class__.__name__ == "CodeAgent":
                        agent_tree.add(
                            f"✅ [italic #1E90FF]Authorized imports:[/italic #1E90FF] {managed_agent.additional_authorized_imports}"
                        )
                    agent_tree.add(f"📝 [italic #1E90FF]Description:[/italic #1E90FF] {managed_agent.description}")
                    build_agent_tree(agent_tree, managed_agent)

        main_tree = Tree(get_agent_headline(agent))
        if agent.__class__.__name__ == "CodeAgent":
            main_tree.add(
                f"✅ [italic #1E90FF]Authorized imports:[/italic #1E90FF] {agent.additional_authorized_imports}"
            )
        build_agent_tree(main_tree, agent)
        self.console.print(main_tree)


# Module-level singleton; call logger.init_logger(log_path) before use.
logger = AgentLogger()
|
src/utils/__init__.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .path_utils import get_project_root, assemble_project_path
|
| 2 |
+
from .singleton import Singleton
|
| 3 |
+
|
| 4 |
+
__all__ = [
|
| 5 |
+
"get_project_root",
|
| 6 |
+
"assemble_project_path",
|
| 7 |
+
"Singleton"
|
| 8 |
+
]
|
src/utils/hf_utils.py
ADDED
|
File without changes
|
src/utils/path_utils.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
def get_project_root():
    """Return the absolute project root (two directories above this file)."""
    return str(Path(__file__).resolve().parents[2])
|
| 7 |
+
|
| 8 |
+
def assemble_project_path(path):
    """Assemble a path relative to the project root directory.

    Absolute inputs are returned unchanged; relative ones are joined onto
    the project root.
    """
    if os.path.isabs(path):
        return path
    return os.path.join(get_project_root(), path)
|
src/utils/singleton.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""A singleton metaclass for ensuring only one instance of a class."""
|
| 2 |
+
|
| 3 |
+
import abc
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class Singleton(abc.ABCMeta, type):
    """Metaclass that caches and reuses a single instance per class."""

    _instances = {}

    def __call__(cls, *args, **kwargs):
        """Build the instance on first call; hand back the cached one after."""
        try:
            return cls._instances[cls]
        except KeyError:
            instance = super(Singleton, cls).__call__(*args, **kwargs)
            cls._instances[cls] = instance
            return instance
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class AbstractSingleton(abc.ABC, metaclass=Singleton):
    """
    Abstract singleton class for ensuring only one instance of a class.

    Subclass this to combine abstract-method enforcement (abc.ABC) with the
    per-class instance caching provided by the Singleton metaclass.
    """

    pass
|