Spaces:
				
			
			
	
			
			
		Sleeping
		
	
	
	
			
			
	
	
	
	
		
		
		Sleeping
		
	
		DVampire
		
	commited on
		
		
					Commit 
							
							Β·
						
						bf5c0e0
	
1
								Parent(s):
							
							d3e5344
								
update website
Browse files- configs/paper_agent.py +4 -0
- src/config/__init__.py +3 -0
- src/config/config.py +86 -0
- src/database/__init__.py +5 -0
- src/database/db.py +143 -0
- src/logger/__init__.py +10 -0
- src/logger/logger.py +229 -0
- src/utils/__init__.py +8 -0
- src/utils/hf_utils.py +0 -0
- src/utils/path_utils.py +12 -0
- src/utils/singleton.py +25 -0
    	
        configs/paper_agent.py
    ADDED
    
    | @@ -0,0 +1,4 @@ | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            workdir = "workdir"
         | 
| 2 | 
            +
            tag = "paper_agent"
         | 
| 3 | 
            +
            exp_path = f"{workdir}/{tag}"
         | 
| 4 | 
            +
            log_path = "agent.log"
         | 
    	
        src/config/__init__.py
    ADDED
    
    | @@ -0,0 +1,3 @@ | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from .config import config
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            __all__ = ['config']
         | 
    	
        src/config/config.py
    ADDED
    
    | @@ -0,0 +1,86 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            from mmengine import Config as MMConfig
         | 
| 3 | 
            +
            from argparse import Namespace
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            from dotenv import load_dotenv
         | 
| 6 | 
            +
            load_dotenv(verbose=True)
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            from finworld.utils import assemble_project_path, get_tag_name, Singleton, set_seed
         | 
| 9 | 
            +
             | 
| 10 | 
            +
            def check_level(level: str) -> bool:
         | 
| 11 | 
            +
                """
         | 
| 12 | 
            +
                Check if the level is valid.
         | 
| 13 | 
            +
                """
         | 
| 14 | 
            +
                valid_levels = ['1day', '1min', '5min', '15min', '30min', '1hour', '4hour']
         | 
| 15 | 
            +
                if level not in valid_levels:
         | 
| 16 | 
            +
                    return False
         | 
| 17 | 
            +
                return True
         | 
| 18 | 
            +
             | 
| 19 | 
            +
            def process_general(config: MMConfig) -> MMConfig:
         | 
| 20 | 
            +
             | 
| 21 | 
            +
                config.exp_path = assemble_project_path(os.path.join(config.workdir, config.tag))
         | 
| 22 | 
            +
                os.makedirs(config.exp_path, exist_ok=True)
         | 
| 23 | 
            +
             | 
| 24 | 
            +
                config.log_path = os.path.join(config.exp_path, getattr(config, 'log_path', 'finworld.log'))
         | 
| 25 | 
            +
             | 
| 26 | 
            +
                if "checkpoint_path" in config:
         | 
| 27 | 
            +
                    config.checkpoint_path = os.path.join(config.exp_path, getattr(config, 'checkpoint_path', 'checkpoint'))
         | 
| 28 | 
            +
                    os.makedirs(config.checkpoint_path, exist_ok=True)
         | 
| 29 | 
            +
             | 
| 30 | 
            +
                if "plot_path" in config:
         | 
| 31 | 
            +
                    config.plot_path = os.path.join(config.exp_path, getattr(config, 'plot_path', 'plot'))
         | 
| 32 | 
            +
                    os.makedirs(config.plot_path, exist_ok=True)
         | 
| 33 | 
            +
             | 
| 34 | 
            +
                if "tracker" in config:
         | 
| 35 | 
            +
                    for key, value in config.tracker.items():
         | 
| 36 | 
            +
                        config.tracker[key]['logging_dir'] = os.path.join(config.exp_path, value['logging_dir'])
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                if "seed" in config:
         | 
| 39 | 
            +
                    set_seed(config.seed)
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                return config
         | 
| 42 | 
            +
             | 
| 43 | 
            +
             | 
| 44 | 
            +
            class Config(MMConfig, metaclass=Singleton):
         | 
| 45 | 
            +
                def __init__(self):
         | 
| 46 | 
            +
                    super(Config, self).__init__()
         | 
| 47 | 
            +
             | 
| 48 | 
            +
                def init_config(self, config_path: str, args: Namespace) -> None:
         | 
| 49 | 
            +
                    # Initialize the general configuration
         | 
| 50 | 
            +
                    mmconfig = MMConfig.fromfile(filename=assemble_project_path(config_path))
         | 
| 51 | 
            +
                    if 'cfg_options' not in args or args.cfg_options is None:
         | 
| 52 | 
            +
                        cfg_options = dict()
         | 
| 53 | 
            +
                    else:
         | 
| 54 | 
            +
                        cfg_options = args.cfg_options
         | 
| 55 | 
            +
                    for item in args.__dict__:
         | 
| 56 | 
            +
                        if item not in ['config', 'cfg_options'] and args.__dict__[item] is not None:
         | 
| 57 | 
            +
                            cfg_options[item] = args.__dict__[item]
         | 
| 58 | 
            +
                    mmconfig.merge_from_dict(cfg_options)
         | 
| 59 | 
            +
             | 
| 60 | 
            +
                    tag = get_tag_name(
         | 
| 61 | 
            +
                        tag=getattr(mmconfig, 'tag', None),
         | 
| 62 | 
            +
                        assets_name=getattr(mmconfig, 'assets_name', None),
         | 
| 63 | 
            +
                        source=getattr(mmconfig, 'source', None),
         | 
| 64 | 
            +
                        data_type= getattr(mmconfig, 'data_type', None),
         | 
| 65 | 
            +
                        level= getattr(mmconfig, 'level', None),
         | 
| 66 | 
            +
                    )
         | 
| 67 | 
            +
                    mmconfig.tag = tag
         | 
| 68 | 
            +
             | 
| 69 | 
            +
                    # Process general configuration
         | 
| 70 | 
            +
                    mmconfig = process_general(mmconfig)
         | 
| 71 | 
            +
             | 
| 72 | 
            +
                    # Initialize the price downloader configuration
         | 
| 73 | 
            +
                    if 'downloader' in mmconfig:
         | 
| 74 | 
            +
                        if "assets_path" in mmconfig.downloader:
         | 
| 75 | 
            +
                            mmconfig.downloader.assets_path = assemble_project_path(mmconfig.downloader.assets_path)
         | 
| 76 | 
            +
                            assert check_level(mmconfig.downloader.level), f"Invalid level: {mmconfig.downloader.level}. Valid levels are: ['1day', '1min', '5min', '15min', '30min', '1hour', '4hour']"
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                    if 'processor' in mmconfig:
         | 
| 79 | 
            +
                        if "assets_path" in mmconfig.processor:
         | 
| 80 | 
            +
                            mmconfig.processor.assets_path = assemble_project_path(mmconfig.processor.assets_path)
         | 
| 81 | 
            +
                        mmconfig.processor.repo_id = f"{os.getenv('HF_REPO_NAME')}/{mmconfig.processor.repo_id}"
         | 
| 82 | 
            +
                        mmconfig.processor.repo_type = mmconfig.processor.repo_type if 'repo_type' in mmconfig.processor else 'dataset'
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                    self.__dict__.update(mmconfig.__dict__)
         | 
| 85 | 
            +
             | 
| 86 | 
            +
            config = Config()
         | 
    	
        src/database/__init__.py
    ADDED
    
    | @@ -0,0 +1,5 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            # Database management for paper caching
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            from .db import PapersDatabase, db
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            __all__ = ['PapersDatabase', 'db']
         | 
    	
        src/database/db.py
    ADDED
    
    | @@ -0,0 +1,143 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import os
         | 
| 2 | 
            +
            import json
         | 
| 3 | 
            +
            import sqlite3
         | 
| 4 | 
            +
            from datetime import date, datetime, timedelta
         | 
| 5 | 
            +
            from typing import Any, Dict, List, Optional
         | 
| 6 | 
            +
            from contextlib import contextmanager
         | 
| 7 | 
            +
             | 
| 8 | 
            +
             | 
| 9 | 
            +
            class PapersDatabase():
         | 
| 10 | 
            +
                def __init__(self, **kwargs):
         | 
| 11 | 
            +
                    super().__init__(**kwargs)
         | 
| 12 | 
            +
                    self.db_path = None
         | 
| 13 | 
            +
                
         | 
| 14 | 
            +
                def init_db(self, config):
         | 
| 15 | 
            +
                    """Initialize the database with required tables"""
         | 
| 16 | 
            +
                    
         | 
| 17 | 
            +
                    self.db_path = config.db_path
         | 
| 18 | 
            +
                    
         | 
| 19 | 
            +
                    with self.get_connection() as conn:
         | 
| 20 | 
            +
                        cursor = conn.cursor()
         | 
| 21 | 
            +
                        
         | 
| 22 | 
            +
                        # Create papers cache table
         | 
| 23 | 
            +
                        cursor.execute('''
         | 
| 24 | 
            +
                            CREATE TABLE IF NOT EXISTS papers_cache (
         | 
| 25 | 
            +
                                date_str TEXT PRIMARY KEY,
         | 
| 26 | 
            +
                                html_content TEXT NOT NULL,
         | 
| 27 | 
            +
                                parsed_cards TEXT NOT NULL,
         | 
| 28 | 
            +
                                created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
         | 
| 29 | 
            +
                                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
         | 
| 30 | 
            +
                            )
         | 
| 31 | 
            +
                        ''')
         | 
| 32 | 
            +
                        
         | 
| 33 | 
            +
                        # Create latest_date table to track the most recent available date
         | 
| 34 | 
            +
                        cursor.execute('''
         | 
| 35 | 
            +
                            CREATE TABLE IF NOT EXISTS latest_date (
         | 
| 36 | 
            +
                                id INTEGER PRIMARY KEY CHECK (id = 1),
         | 
| 37 | 
            +
                                date_str TEXT NOT NULL,
         | 
| 38 | 
            +
                                updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
         | 
| 39 | 
            +
                            )
         | 
| 40 | 
            +
                        ''')
         | 
| 41 | 
            +
                        
         | 
| 42 | 
            +
                        # Insert default latest_date record if it doesn't exist
         | 
| 43 | 
            +
                        cursor.execute('''
         | 
| 44 | 
            +
                            INSERT OR IGNORE INTO latest_date (id, date_str) 
         | 
| 45 | 
            +
                            VALUES (1, ?)
         | 
| 46 | 
            +
                        ''', (date.today().isoformat(),))
         | 
| 47 | 
            +
                        
         | 
| 48 | 
            +
                        conn.commit()
         | 
| 49 | 
            +
                
         | 
| 50 | 
            +
                @contextmanager
         | 
| 51 | 
            +
                def get_connection(self):
         | 
| 52 | 
            +
                    """Context manager for database connections"""
         | 
| 53 | 
            +
                    conn = sqlite3.connect(self.db_path)
         | 
| 54 | 
            +
                    conn.row_factory = sqlite3.Row  # Enable dict-like access
         | 
| 55 | 
            +
                    try:
         | 
| 56 | 
            +
                        yield conn
         | 
| 57 | 
            +
                    finally:
         | 
| 58 | 
            +
                        conn.close()
         | 
| 59 | 
            +
                
         | 
| 60 | 
            +
                def get_cached_papers(self, date_str: str) -> Optional[Dict[str, Any]]:
         | 
| 61 | 
            +
                    """Get cached papers for a specific date"""
         | 
| 62 | 
            +
                    with self.get_connection(self.db_path) as conn:
         | 
| 63 | 
            +
                        cursor = conn.cursor()
         | 
| 64 | 
            +
                        cursor.execute('''
         | 
| 65 | 
            +
                            SELECT parsed_cards, created_at 
         | 
| 66 | 
            +
                            FROM papers_cache 
         | 
| 67 | 
            +
                            WHERE date_str = ?
         | 
| 68 | 
            +
                        ''', (date_str,))
         | 
| 69 | 
            +
                        
         | 
| 70 | 
            +
                        row = cursor.fetchone()
         | 
| 71 | 
            +
                        if row:
         | 
| 72 | 
            +
                            return {
         | 
| 73 | 
            +
                                'cards': json.loads(row['parsed_cards']),
         | 
| 74 | 
            +
                                'cached_at': row['created_at']
         | 
| 75 | 
            +
                            }
         | 
| 76 | 
            +
                        return None
         | 
| 77 | 
            +
                
         | 
| 78 | 
            +
                def cache_papers(self, date_str: str, html_content: str, parsed_cards: List[Dict[str, Any]]):
         | 
| 79 | 
            +
                    """Cache papers for a specific date"""
         | 
| 80 | 
            +
                    with self.get_connection() as conn:
         | 
| 81 | 
            +
                        cursor = conn.cursor()
         | 
| 82 | 
            +
                        cursor.execute('''
         | 
| 83 | 
            +
                            INSERT OR REPLACE INTO papers_cache 
         | 
| 84 | 
            +
                            (date_str, html_content, parsed_cards, updated_at) 
         | 
| 85 | 
            +
                            VALUES (?, ?, ?, CURRENT_TIMESTAMP)
         | 
| 86 | 
            +
                        ''', (date_str, html_content, json.dumps(parsed_cards)))
         | 
| 87 | 
            +
                        conn.commit()
         | 
| 88 | 
            +
                
         | 
| 89 | 
            +
                def get_latest_cached_date(self) -> Optional[str]:
         | 
| 90 | 
            +
                    """Get the latest cached date"""
         | 
| 91 | 
            +
                    with self.get_connection() as conn:
         | 
| 92 | 
            +
                        cursor = conn.cursor()
         | 
| 93 | 
            +
                        cursor.execute('SELECT date_str FROM latest_date WHERE id = 1')
         | 
| 94 | 
            +
                        row = cursor.fetchone()
         | 
| 95 | 
            +
                        return row['date_str'] if row else None
         | 
| 96 | 
            +
                
         | 
| 97 | 
            +
                def update_latest_date(self, date_str: str):
         | 
| 98 | 
            +
                    """Update the latest available date"""
         | 
| 99 | 
            +
                    with self.get_connection() as conn:
         | 
| 100 | 
            +
                        cursor = conn.cursor()
         | 
| 101 | 
            +
                        cursor.execute('''
         | 
| 102 | 
            +
                            UPDATE latest_date 
         | 
| 103 | 
            +
                            SET date_str = ?, updated_at = CURRENT_TIMESTAMP 
         | 
| 104 | 
            +
                            WHERE id = 1
         | 
| 105 | 
            +
                        ''', (date_str,))
         | 
| 106 | 
            +
                        conn.commit()
         | 
| 107 | 
            +
                
         | 
| 108 | 
            +
                def is_cache_fresh(self, date_str: str, max_age_hours: int = 24) -> bool:
         | 
| 109 | 
            +
                    """Check if cache is fresh (within max_age_hours)"""
         | 
| 110 | 
            +
                    with self.get_connection() as conn:
         | 
| 111 | 
            +
                        cursor = conn.cursor()
         | 
| 112 | 
            +
                        cursor.execute('''
         | 
| 113 | 
            +
                            SELECT updated_at 
         | 
| 114 | 
            +
                            FROM papers_cache 
         | 
| 115 | 
            +
                            WHERE date_str = ?
         | 
| 116 | 
            +
                        ''', (date_str,))
         | 
| 117 | 
            +
                        
         | 
| 118 | 
            +
                        row = cursor.fetchone()
         | 
| 119 | 
            +
                        if not row:
         | 
| 120 | 
            +
                            return False
         | 
| 121 | 
            +
                        
         | 
| 122 | 
            +
                        cached_time = datetime.fromisoformat(row['updated_at'].replace('Z', '+00:00'))
         | 
| 123 | 
            +
                        age = datetime.now(cached_time.tzinfo) - cached_time
         | 
| 124 | 
            +
                        return age.total_seconds() < max_age_hours * 3600
         | 
| 125 | 
            +
                
         | 
| 126 | 
            +
                def cleanup_old_cache(self, days_to_keep: int = 7):
         | 
| 127 | 
            +
                    """Clean up old cache entries"""
         | 
| 128 | 
            +
                    cutoff_date = (datetime.now() - timedelta(days=days_to_keep)).isoformat()
         | 
| 129 | 
            +
                    with self.get_connection() as conn:
         | 
| 130 | 
            +
                        cursor = conn.cursor()
         | 
| 131 | 
            +
                        cursor.execute('''
         | 
| 132 | 
            +
                            DELETE FROM papers_cache 
         | 
| 133 | 
            +
                            WHERE updated_at < ?
         | 
| 134 | 
            +
                        ''', (cutoff_date,))
         | 
| 135 | 
            +
                        conn.commit()
         | 
| 136 | 
            +
             | 
| 137 | 
            +
                def __str__(self):
         | 
| 138 | 
            +
                    return f"PapersDatabase(db_path={self.db_path})"
         | 
| 139 | 
            +
             | 
| 140 | 
            +
                def __repr__(self):
         | 
| 141 | 
            +
                    return self.__str__()
         | 
| 142 | 
            +
             | 
| 143 | 
            +
            db = PapersDatabase()
         | 
    	
        src/logger/__init__.py
    ADDED
    
    | @@ -0,0 +1,10 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from .logger import logger, LogLevel, AgentLogger, YELLOW_HEX
         | 
| 2 | 
            +
            from .monitor import Monitor, Timing, TokenUsage
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            __all__ = ["logger",
         | 
| 5 | 
            +
                       "LogLevel",
         | 
| 6 | 
            +
                       "AgentLogger",
         | 
| 7 | 
            +
                       "Monitor",
         | 
| 8 | 
            +
                       "YELLOW_HEX",
         | 
| 9 | 
            +
                       "Timing",
         | 
| 10 | 
            +
                       "TokenUsage"]
         | 
    	
        src/logger/logger.py
    ADDED
    
    | @@ -0,0 +1,229 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import logging
         | 
| 2 | 
            +
            import json
         | 
| 3 | 
            +
            from enum import IntEnum
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            from rich import box
         | 
| 6 | 
            +
            from rich.console import Console, Group
         | 
| 7 | 
            +
            from rich.panel import Panel
         | 
| 8 | 
            +
            from rich.rule import Rule
         | 
| 9 | 
            +
            from rich.syntax import Syntax
         | 
| 10 | 
            +
            from rich.table import Table
         | 
| 11 | 
            +
            from rich.tree import Tree
         | 
| 12 | 
            +
             | 
| 13 | 
            +
            from src.utils import (
         | 
| 14 | 
            +
                escape_code_brackets,
         | 
| 15 | 
            +
                Singleton
         | 
| 16 | 
            +
            )
         | 
| 17 | 
            +
             | 
| 18 | 
            +
            YELLOW_HEX = "#d4b702"
         | 
| 19 | 
            +
             | 
| 20 | 
            +
            class LogLevel(IntEnum):
         | 
| 21 | 
            +
                OFF = -1  # No output
         | 
| 22 | 
            +
                ERROR = 0  # Only errors
         | 
| 23 | 
            +
                INFO = 1  # Normal output (default)
         | 
| 24 | 
            +
                DEBUG = 2  # Detailed output
         | 
| 25 | 
            +
             | 
| 26 | 
            +
            class AgentLogger(logging.Logger, metaclass=Singleton):
         | 
| 27 | 
            +
                def __init__(self, name="logger", level=logging.INFO):
         | 
| 28 | 
            +
                    # Initialize the parent class
         | 
| 29 | 
            +
                    super().__init__(name, level)
         | 
| 30 | 
            +
             | 
| 31 | 
            +
                    # Define a formatter for log messages
         | 
| 32 | 
            +
                    self.formatter = logging.Formatter(
         | 
| 33 | 
            +
                        fmt="\033[92m%(asctime)s - %(name)s:%(levelname)s\033[0m: %(filename)s:%(lineno)s - %(message)s",
         | 
| 34 | 
            +
                        datefmt="%H:%M:%S",
         | 
| 35 | 
            +
                    )
         | 
| 36 | 
            +
             | 
| 37 | 
            +
                def init_logger(self, log_path: str, level=logging.INFO):
         | 
| 38 | 
            +
                    """
         | 
| 39 | 
            +
                    Initialize the logger with a file path and optional main process check.
         | 
| 40 | 
            +
             | 
| 41 | 
            +
                    Args:
         | 
| 42 | 
            +
                        log_path (str): The log file path.
         | 
| 43 | 
            +
                        level (int, optional): The logging level. Defaults to logging.INFO.
         | 
| 44 | 
            +
                        accelerator (Accelerator, optional): Accelerator instance to determine the main process.
         | 
| 45 | 
            +
                    """
         | 
| 46 | 
            +
             | 
| 47 | 
            +
                    # Add a console handler for logging to the console
         | 
| 48 | 
            +
                    console_handler = logging.StreamHandler()
         | 
| 49 | 
            +
                    console_handler.setLevel(level)
         | 
| 50 | 
            +
                    console_handler.setFormatter(self.formatter)
         | 
| 51 | 
            +
                    self.addHandler(console_handler)
         | 
| 52 | 
            +
             | 
| 53 | 
            +
                    # Add a file handler for logging to the file
         | 
| 54 | 
            +
                    file_handler = logging.FileHandler(
         | 
| 55 | 
            +
                        log_path, mode="a"
         | 
| 56 | 
            +
                    )  # 'a' mode appends to the file
         | 
| 57 | 
            +
                    file_handler.setLevel(level)
         | 
| 58 | 
            +
                    file_handler.setFormatter(self.formatter)
         | 
| 59 | 
            +
                    self.addHandler(file_handler)
         | 
| 60 | 
            +
             | 
| 61 | 
            +
                    self.console = Console(width=100)
         | 
| 62 | 
            +
                    self.file_console = Console(file=open(log_path, "a"), width=100)
         | 
| 63 | 
            +
             | 
| 64 | 
            +
                    # Prevent duplicate logs from propagating to the root logger
         | 
| 65 | 
            +
                    self.propagate = False
         | 
| 66 | 
            +
             | 
| 67 | 
            +
                def log(self, *args, level: int | str | LogLevel = LogLevel.INFO, **kwargs) -> None:
         | 
| 68 | 
            +
                    """Logs a message to the console.
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                    Args:
         | 
| 71 | 
            +
                        level (LogLevel, optional): Defaults to LogLevel.INFO.
         | 
| 72 | 
            +
                    """
         | 
| 73 | 
            +
                    if isinstance(level, str):
         | 
| 74 | 
            +
                        level = LogLevel[level.upper()]
         | 
| 75 | 
            +
                    if level <= self.level:
         | 
| 76 | 
            +
                        self.info(*args, **kwargs)
         | 
| 77 | 
            +
             | 
| 78 | 
            +
                def info(self, msg, *args, **kwargs):
         | 
| 79 | 
            +
                    """
         | 
| 80 | 
            +
                    Overridden info method with stacklevel adjustment for correct log location.
         | 
| 81 | 
            +
                    """
         | 
| 82 | 
            +
                    if isinstance(msg, (Rule, Panel, Group, Tree, Table, Syntax)):
         | 
| 83 | 
            +
                        self.console.print(msg)
         | 
| 84 | 
            +
                        self.file_console.print(msg)
         | 
| 85 | 
            +
                    else:
         | 
| 86 | 
            +
                        kwargs.setdefault(
         | 
| 87 | 
            +
                            "stacklevel", 2
         | 
| 88 | 
            +
                        )  # Adjust stack level to show the actual caller
         | 
| 89 | 
            +
                        if "style" in kwargs:
         | 
| 90 | 
            +
                            kwargs.pop("style")
         | 
| 91 | 
            +
                        if "level" in kwargs:
         | 
| 92 | 
            +
                            kwargs.pop("level")
         | 
| 93 | 
            +
                        super().info(msg, *args, **kwargs)
         | 
| 94 | 
            +
             | 
| 95 | 
            +
                def warning(self, msg, *args, **kwargs):
         | 
| 96 | 
            +
                    kwargs.setdefault("stacklevel", 2)
         | 
| 97 | 
            +
                    super().warning(msg, *args, **kwargs)
         | 
| 98 | 
            +
             | 
| 99 | 
            +
                def error(self, msg, *args, **kwargs):
         | 
| 100 | 
            +
                    kwargs.setdefault("stacklevel", 2)
         | 
| 101 | 
            +
                    super().error(msg, *args, **kwargs)
         | 
| 102 | 
            +
             | 
| 103 | 
            +
                def critical(self, msg, *args, **kwargs):
         | 
| 104 | 
            +
                    kwargs.setdefault("stacklevel", 2)
         | 
| 105 | 
            +
                    super().critical(msg, *args, **kwargs)
         | 
| 106 | 
            +
             | 
| 107 | 
            +
                def debug(self, msg, *args, **kwargs):
         | 
| 108 | 
            +
                    kwargs.setdefault("stacklevel", 2)
         | 
| 109 | 
            +
                    super().debug(msg, *args, **kwargs)
         | 
| 110 | 
            +
             | 
| 111 | 
            +
                def log_error(self, error_message: str) -> None:
         | 
| 112 | 
            +
                    self.info(escape_code_brackets(error_message), style="bold red", level=LogLevel.ERROR)
         | 
| 113 | 
            +
             | 
| 114 | 
            +
                def log_markdown(self, content: str, title: str | None = None, level=LogLevel.INFO, style=YELLOW_HEX) -> None:
         | 
| 115 | 
            +
                    markdown_content = Syntax(
         | 
| 116 | 
            +
                        content,
         | 
| 117 | 
            +
                        lexer="markdown",
         | 
| 118 | 
            +
                        theme="github-dark",
         | 
| 119 | 
            +
                        word_wrap=True,
         | 
| 120 | 
            +
                    )
         | 
| 121 | 
            +
                    if title:
         | 
| 122 | 
            +
                        self.info(
         | 
| 123 | 
            +
                            Group(
         | 
| 124 | 
            +
                                Rule(
         | 
| 125 | 
            +
                                    "[bold italic]" + title,
         | 
| 126 | 
            +
                                    align="left",
         | 
| 127 | 
            +
                                    style=style,
         | 
| 128 | 
            +
                                ),
         | 
| 129 | 
            +
                                markdown_content,
         | 
| 130 | 
            +
                            ),
         | 
| 131 | 
            +
                            level=level,
         | 
| 132 | 
            +
                        )
         | 
| 133 | 
            +
                    else:
         | 
| 134 | 
            +
                        self.info(markdown_content, level=level)
         | 
| 135 | 
            +
             | 
| 136 | 
            +
                def log_code(self, title: str, content: str, level: int = LogLevel.INFO) -> None:
         | 
| 137 | 
            +
                    self.info(
         | 
| 138 | 
            +
                        Panel(
         | 
| 139 | 
            +
                            Syntax(
         | 
| 140 | 
            +
                                content,
         | 
| 141 | 
            +
                                lexer="python",
         | 
| 142 | 
            +
                                theme="monokai",
         | 
| 143 | 
            +
                                word_wrap=True,
         | 
| 144 | 
            +
                            ),
         | 
| 145 | 
            +
                            title="[bold]" + title,
         | 
| 146 | 
            +
                            title_align="left",
         | 
| 147 | 
            +
                            box=box.HORIZONTALS,
         | 
| 148 | 
            +
                        ),
         | 
| 149 | 
            +
                        level=level,
         | 
| 150 | 
            +
                    )
         | 
| 151 | 
            +
             | 
| 152 | 
            +
                def log_rule(self, title: str, level: int = LogLevel.INFO) -> None:
         | 
| 153 | 
            +
                    self.info(
         | 
| 154 | 
            +
                        Rule(
         | 
| 155 | 
            +
                            "[bold]" + title,
         | 
| 156 | 
            +
                            characters="β",
         | 
| 157 | 
            +
                            style=YELLOW_HEX,
         | 
| 158 | 
            +
                        ),
         | 
| 159 | 
            +
                        level=LogLevel.INFO,
         | 
| 160 | 
            +
                    )
         | 
| 161 | 
            +
             | 
| 162 | 
            +
                def log_task(self, content: str, subtitle: str, title: str | None = None, level: LogLevel = LogLevel.INFO) -> None:
         | 
| 163 | 
            +
                    self.info(
         | 
| 164 | 
            +
                        Panel(
         | 
| 165 | 
            +
                            f"\n[bold]{escape_code_brackets(content)}\n",
         | 
| 166 | 
            +
                            title="[bold]New run" + (f" - {title}" if title else ""),
         | 
| 167 | 
            +
                            subtitle=subtitle,
         | 
| 168 | 
            +
                            border_style=YELLOW_HEX,
         | 
| 169 | 
            +
                            subtitle_align="left",
         | 
| 170 | 
            +
                        ),
         | 
| 171 | 
            +
                        level=level,
         | 
| 172 | 
            +
                    )
         | 
| 173 | 
            +
             | 
| 174 | 
            +
                def log_messages(self, messages: list[dict], level: LogLevel = LogLevel.DEBUG) -> None:
         | 
| 175 | 
            +
                    messages_as_string = "\n".join([json.dumps(dict(message), indent=4, ensure_ascii=False) for message in messages])
         | 
| 176 | 
            +
                    self.info(
         | 
| 177 | 
            +
                        Syntax(
         | 
| 178 | 
            +
                            messages_as_string,
         | 
| 179 | 
            +
                            lexer="markdown",
         | 
| 180 | 
            +
                            theme="github-dark",
         | 
| 181 | 
            +
                            word_wrap=True,
         | 
| 182 | 
            +
                        ),
         | 
| 183 | 
            +
                        level=level,
         | 
| 184 | 
            +
                    )
         | 
| 185 | 
            +
             | 
| 186 | 
            +
                def visualize_agent_tree(self, agent):
         | 
| 187 | 
            +
                    def create_tools_section(tools_dict):
         | 
| 188 | 
            +
                        table = Table(show_header=True, header_style="bold")
         | 
| 189 | 
            +
                        table.add_column("Name", style="#1E90FF")
         | 
| 190 | 
            +
                        table.add_column("Description")
         | 
| 191 | 
            +
                        table.add_column("Arguments")
         | 
| 192 | 
            +
             | 
| 193 | 
            +
                        for name, tool in tools_dict.items():
         | 
| 194 | 
            +
                            args = [
         | 
| 195 | 
            +
                                f"{arg_name} (`{info.get('type', 'Any')}`{', optional' if info.get('optional') else ''}): {info.get('description', '')}"
         | 
| 196 | 
            +
                                for arg_name, info in getattr(tool, "inputs", {}).items()
         | 
| 197 | 
            +
                            ]
         | 
| 198 | 
            +
                            table.add_row(name, getattr(tool, "description", str(tool)), "\n".join(args))
         | 
| 199 | 
            +
             | 
| 200 | 
            +
                        return Group("π οΈ [italic #1E90FF]Tools:[/italic #1E90FF]", table)
         | 
| 201 | 
            +
             | 
| 202 | 
            +
                    def get_agent_headline(agent, name: str | None = None):
         | 
| 203 | 
            +
                        name_headline = f"{name} | " if name else ""
         | 
| 204 | 
            +
                        return f"[bold {YELLOW_HEX}]{name_headline}{agent.__class__.__name__} | {agent.model.model_id}"
         | 
| 205 | 
            +
             | 
| 206 | 
            +
                    def build_agent_tree(parent_tree, agent_obj):
         | 
| 207 | 
            +
                        """Recursively builds the agent tree."""
         | 
| 208 | 
            +
                        parent_tree.add(create_tools_section(agent_obj.tools))
         | 
| 209 | 
            +
             | 
| 210 | 
            +
                        if agent_obj.managed_agents:
         | 
| 211 | 
            +
                            agents_branch = parent_tree.add("π€ [italic #1E90FF]Managed agents:")
         | 
| 212 | 
            +
                            for name, managed_agent in agent_obj.managed_agents.items():
         | 
| 213 | 
            +
                                agent_tree = agents_branch.add(get_agent_headline(managed_agent, name))
         | 
| 214 | 
            +
                                if managed_agent.__class__.__name__ == "CodeAgent":
         | 
| 215 | 
            +
                                    agent_tree.add(
         | 
| 216 | 
            +
                                        f"β
 [italic #1E90FF]Authorized imports:[/italic #1E90FF] {managed_agent.additional_authorized_imports}"
         | 
| 217 | 
            +
                                    )
         | 
| 218 | 
            +
                                agent_tree.add(f"π [italic #1E90FF]Description:[/italic #1E90FF] {managed_agent.description}")
         | 
| 219 | 
            +
                                build_agent_tree(agent_tree, managed_agent)
         | 
| 220 | 
            +
             | 
| 221 | 
            +
                    main_tree = Tree(get_agent_headline(agent))
         | 
| 222 | 
            +
                    if agent.__class__.__name__ == "CodeAgent":
         | 
| 223 | 
            +
                        main_tree.add(
         | 
| 224 | 
            +
                            f"β
 [italic #1E90FF]Authorized imports:[/italic #1E90FF] {agent.additional_authorized_imports}"
         | 
| 225 | 
            +
                        )
         | 
| 226 | 
            +
                    build_agent_tree(main_tree, agent)
         | 
| 227 | 
            +
                    self.console.print(main_tree)
         | 
| 228 | 
            +
             | 
| 229 | 
            +
            logger = AgentLogger()
         | 
    	
        src/utils/__init__.py
    ADDED
    
    | @@ -0,0 +1,8 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from .path_utils import get_project_root, assemble_project_path
         | 
| 2 | 
            +
            from .singleton import Singleton
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            __all__ = [
         | 
| 5 | 
            +
                "get_project_root",
         | 
| 6 | 
            +
                "assemble_project_path",
         | 
| 7 | 
            +
                "Singleton"
         | 
| 8 | 
            +
            ]
         | 
    	
        src/utils/hf_utils.py
    ADDED
    
    | 
            File without changes
         | 
    	
        src/utils/path_utils.py
    ADDED
    
    | @@ -0,0 +1,12 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from pathlib import Path
         | 
| 2 | 
            +
            import os
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            def get_project_root():
         | 
| 5 | 
            +
                root = str(Path(__file__).resolve().parents[2])
         | 
| 6 | 
            +
                return root
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            def assemble_project_path(path):
         | 
| 9 | 
            +
                """Assemble a path relative to the project root directory"""
         | 
| 10 | 
            +
                if not os.path.isabs(path):
         | 
| 11 | 
            +
                    path = os.path.join(get_project_root(), path)
         | 
| 12 | 
            +
                return path
         | 
    	
        src/utils/singleton.py
    ADDED
    
    | @@ -0,0 +1,25 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            """A singleton metaclass for ensuring only one instance of a class."""
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import abc
         | 
| 4 | 
            +
             | 
| 5 | 
            +
             | 
| 6 | 
            +
            class Singleton(abc.ABCMeta, type):
         | 
| 7 | 
            +
                """
         | 
| 8 | 
            +
                Singleton metaclass for ensuring only one instance of a class.
         | 
| 9 | 
            +
                """
         | 
| 10 | 
            +
             | 
| 11 | 
            +
                _instances = {}
         | 
| 12 | 
            +
             | 
| 13 | 
            +
                def __call__(cls, *args, **kwargs):
         | 
| 14 | 
            +
                    """Call method for the singleton metaclass."""
         | 
| 15 | 
            +
                    if cls not in cls._instances:
         | 
| 16 | 
            +
                        cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
         | 
| 17 | 
            +
                    return cls._instances[cls]
         | 
| 18 | 
            +
             | 
| 19 | 
            +
             | 
| 20 | 
            +
            class AbstractSingleton(abc.ABC, metaclass=Singleton):
         | 
| 21 | 
            +
                """
         | 
| 22 | 
            +
                Abstract singleton class for ensuring only one instance of a class.
         | 
| 23 | 
            +
                """
         | 
| 24 | 
            +
             | 
| 25 | 
            +
                pass
         |