Spaces:
Running
Running
File size: 4,609 Bytes
fe5c39d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 |
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@Time : 2024/1/4 16:32
@Author : alexanderwu
@File : context.py
"""
import os
from pathlib import Path
from typing import Any, Dict, Optional
from pydantic import BaseModel, ConfigDict
from metagpt.config2 import Config
from metagpt.configs.llm_config import LLMConfig, LLMType
from metagpt.provider.base_llm import BaseLLM
from metagpt.provider.llm_provider_registry import create_llm_instance
from metagpt.utils.cost_manager import (
CostManager,
FireworksCostManager,
TokenCostManager,
)
from metagpt.utils.git_repository import GitRepository
from metagpt.utils.project_repo import ProjectRepo
class AttrDict(BaseModel):
    """Pydantic model that behaves like a dict whose keys are readable as attributes.

    Values are kept in the instance ``__dict__``. Reading a name that is not
    stored there yields ``None`` instead of raising ``AttributeError``.
    """

    model_config = ConfigDict(extra="allow")

    def __init__(self, **kwargs):
        """Initialise the model from ``kwargs`` and mirror them into ``__dict__``."""
        super().__init__(**kwargs)
        store = self.__dict__
        for name, value in kwargs.items():
            store[name] = value

    def __getattr__(self, key):
        """Fallback attribute lookup: the stored value for ``key``, or ``None``."""
        return self.__dict__.get(key)

    def __setattr__(self, key, value):
        """Write ``value`` straight into the instance ``__dict__`` under ``key``."""
        store = self.__dict__
        store[key] = value

    def __delattr__(self, key):
        """Drop ``key`` from the instance ``__dict__``.

        Raises:
            AttributeError: if ``key`` is not currently stored.
        """
        store = self.__dict__
        if key not in store:
            raise AttributeError(f"No such attribute: {key}")
        del store[key]

    def set(self, key, val: Any):
        """Store ``val`` under ``key``."""
        store = self.__dict__
        store[key] = val

    def get(self, key, default: Any = None):
        """Return the value stored under ``key``, or ``default`` when it is absent."""
        store = self.__dict__
        return store.get(key, default)

    def remove(self, key):
        """Delete ``key`` if it is stored; do nothing otherwise."""
        if key not in self.__dict__:
            return
        self.__delattr__(key)
class Context(BaseModel):
    """Env context for MetaGPT.

    Holds the configuration, the optional project/git repositories, a cost
    manager and a free-form ``kwargs`` store, and builds LLM instances from
    LLM configs.
    """

    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Free-form key/value store; saved and restored by serialize()/deserialize().
    kwargs: AttrDict = AttrDict()
    config: Config = Config.default()
    # The three fields below default to None; deserialize() sets them when the
    # serialized data carries a non-empty "workdir".
    repo: Optional[ProjectRepo] = None
    git_repo: Optional[GitRepository] = None
    src_workspace: Optional[Path] = None
    # Returned by _select_costmanager() for every api type other than
    # FIREWORKS and OPEN_LLM.
    cost_manager: CostManager = CostManager()

    # Most recent LLM built by llm().
    _llm: Optional[BaseLLM] = None

    def new_environ(self):
        """Return a copy of the current ``os.environ`` mapping."""
        # TODO: the original carried a disabled step that merged string-valued
        # options into the copy; it is still not implemented.
        return os.environ.copy()

    def _select_costmanager(self, llm_config: LLMConfig) -> CostManager:
        """Return the cost manager matching ``llm_config.api_type``.

        Args:
            llm_config: LLM configuration whose ``api_type`` drives the choice.

        Returns:
            A new ``FireworksCostManager`` for ``LLMType.FIREWORKS``, a new
            ``TokenCostManager`` for ``LLMType.OPEN_LLM``, otherwise this
            context's own ``cost_manager``.
        """
        if llm_config.api_type == LLMType.FIREWORKS:
            return FireworksCostManager()
        elif llm_config.api_type == LLMType.OPEN_LLM:
            return TokenCostManager()
        else:
            return self.cost_manager

    def llm(self) -> BaseLLM:
        """Build an LLM from ``self.config.llm``, remember it in ``_llm`` and return it.

        A new instance is created on every call.
        TODO: support cache (reuse ``self._llm`` when it is already set).

        Returns:
            The newly created LLM; if it came without a cost manager, one is
            attached via ``_select_costmanager``.
        """
        self._llm = create_llm_instance(self.config.llm)
        if self._llm.cost_manager is None:
            self._llm.cost_manager = self._select_costmanager(self.config.llm)
        return self._llm

    def llm_with_cost_manager_from_llm_config(self, llm_config: LLMConfig) -> BaseLLM:
        """Build and return an LLM from an explicit ``llm_config``.

        Unlike ``llm()``, the instance is not stored on the context.

        Args:
            llm_config: Configuration used to create the LLM.

        Returns:
            The newly created LLM; if it came without a cost manager, one is
            attached via ``_select_costmanager``.
        """
        llm = create_llm_instance(llm_config)
        if llm.cost_manager is None:
            llm.cost_manager = self._select_costmanager(llm_config)
        return llm

    def serialize(self) -> Dict[str, Any]:
        """Serialize the object's attributes into a dictionary.

        Returns:
            Dict[str, Any]: A dictionary with three keys:
                "workdir": ``str(self.repo.workdir)``, or "" when ``repo`` is unset;
                "kwargs": a shallow copy of the ``kwargs`` store;
                "cost_manager": the cost manager dumped as a JSON string.
        """
        return {
            "workdir": str(self.repo.workdir) if self.repo else "",
            "kwargs": dict(self.kwargs.__dict__),
            "cost_manager": self.cost_manager.model_dump_json(),
        }

    def deserialize(self, serialized_data: Dict[str, Any]):
        """Deserialize the given serialized data and update the object's attributes accordingly.

        Missing or empty entries are skipped, so a partial dictionary only
        updates the parts it contains.

        Args:
            serialized_data (Dict[str, Any]): A dictionary containing serialized data,
                in the layout produced by ``serialize()``.
        """
        if not serialized_data:
            return

        workdir = serialized_data.get("workdir")
        if workdir:
            self.git_repo = GitRepository(local_path=workdir, auto_init=True)
            self.repo = ProjectRepo(self.git_repo)
            # Candidate source directory: a folder named after the work dir, inside it.
            src_workspace = self.git_repo.workdir / self.git_repo.workdir.name
            if src_workspace.exists():
                self.src_workspace = src_workspace

        kwargs = serialized_data.get("kwargs")
        if kwargs:
            for k, v in kwargs.items():
                self.kwargs.set(k, v)

        cost_manager = serialized_data.get("cost_manager")
        if cost_manager:
            # model_validate_json is a classmethod that returns a NEW model and
            # leaves the receiver untouched, so the result must be assigned back;
            # previously it was discarded and the saved costs were lost.
            self.cost_manager = type(self.cost_manager).model_validate_json(cost_manager)
|