import json
from typing import Optional, Union

from metagpt.config2 import config
from metagpt.configs.llm_config import LLMType
from metagpt.logs import logger
from metagpt.provider.azure_openai_api import AzureOpenAILLM
from metagpt.provider.constant import GENERAL_FUNCTION_SCHEMA
from metagpt.provider.openai_api import OpenAILLM
from metagpt.schema import Message

# Choose the concrete provider to subclass from the configured API type:
# the OpenAI client for LLMType.OPENAI, the Azure OpenAI client otherwise.
if config.llm.api_type == LLMType.OPENAI:
    OriginalLLM = OpenAILLM
else:
    OriginalLLM = AzureOpenAILLM


class MockLLM(OriginalLLM):
    """Provider subclass that answers ask-style calls from a response cache.

    Every ``aask*`` entry point derives a string key from its input and routes
    through ``_mock_rsp``. A cached response is returned when present;
    otherwise the real (unmocked) request is issued, but only if
    ``allow_open_api_call`` is truthy. Every response served, cached or fresh,
    is recorded in ``rsp_candidates``.
    """

    def __init__(self, allow_open_api_call):
        # Fetch the LLM config that matches the base class picked at import time.
        if config.llm.api_type == LLMType.OPENAI:
            llm_config = config.get_openai_llm()
        else:
            llm_config = config.get_azure_llm()
        super().__init__(llm_config)
        self.allow_open_api_call = allow_open_api_call
        # Maps cache key -> response. Nothing in this class fills it; presumably
        # the test harness loads it (e.g. from tests/data/rsp_cache.json) — confirm.
        self.rsp_cache: dict = {}
        # A list because one test may issue several calls through the same LLM.
        self.rsp_candidates: list[dict] = []

    async def acompletion_text(self, messages: list[dict], stream=False, timeout=3) -> str:
        """Run a single completion with no retry wrapper (overrides the parent).

        Args:
            messages: Chat messages to send.
            stream: When true, return whatever ``_achat_completion_stream`` yields;
                otherwise extract the text via ``get_choice_text``.
            timeout: Forwarded unchanged to the underlying completion call.
        """
        if not stream:
            completion = await self._achat_completion(messages, timeout=timeout)
            return self.get_choice_text(completion)
        return await self._achat_completion_stream(messages, timeout=timeout)

    async def original_aask(
        self,
        msg: Union[str, list[dict[str, str]]],
        system_msgs: Optional[list[str]] = None,
        format_msgs: Optional[list[dict[str, str]]] = None,
        images: Optional[Union[str, list[str]]] = None,
        timeout=3,
        stream=True,
    ) -> str:
        """Build the message list and call the real model (unmocked ``aask``).

        Message order: system message(s), or none when ``use_system_prompt`` is
        false; then ``format_msgs``; then the user input. A ``str`` input is
        wrapped with ``_user_msg`` (together with ``images``); a list input is
        appended as-is.
        """
        # The system part is always built first and only then discarded when
        # system prompts are disabled, matching the original evaluation order.
        conversation = self._system_msgs(system_msgs) if system_msgs else [self._default_system_msg()]
        if not self.use_system_prompt:
            conversation = []
        if format_msgs:
            conversation += format_msgs
        if not isinstance(msg, str):
            conversation += msg
        else:
            conversation.append(self._user_msg(msg, images=images))
        logger.debug(conversation)
        return await self.acompletion_text(conversation, stream=stream, timeout=timeout)

    async def original_aask_batch(self, msgs: list, timeout=3) -> str:
        """A copy of metagpt.provider.base_llm.BaseLLM.aask_batch, we can't use super().aask because it
        will be mocked.

        Each item in ``msgs`` is sent as a user turn on top of the growing
        conversation; the model's reply is appended as an assistant turn before
        the next item. The result comes from ``_extract_assistant_rsp``.
        """
        history = []
        for item in msgs:
            history.append(self._user_msg(item))
            reply = await self.acompletion_text(history, timeout=timeout)
            history.append(self._assistant_msg(reply))
        return self._extract_assistant_rsp(history)

    async def original_aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
        """A copy of metagpt.provider.openai_api.OpenAILLM.aask_code, we can't use super().aask because it
        will be mocked. Since openai_api.OpenAILLM.aask_code is different from base_llm.BaseLLM.aask_code,
        we use the former.

        When the caller passes no ``tools``, a single function tool built from
        ``GENERAL_FUNCTION_SCHEMA`` is supplied.
        """
        kwargs.setdefault("tools", [{"type": "function", "function": GENERAL_FUNCTION_SCHEMA}])
        completion = await self._achat_completion_function(messages, **kwargs)
        return self.get_choice_function_arguments(completion)

    async def aask(
        self,
        msg: Union[str, list[dict[str, str]]],
        system_msgs: Optional[list[str]] = None,
        format_msgs: Optional[list[dict[str, str]]] = None,
        images: Optional[Union[str, list[str]]] = None,
        timeout=3,
        stream=True,
    ) -> str:
        """Cached ``aask``: answer from the response cache or fall back to ``original_aask``.

        The cache key (used to tell whether a message has been seen before) is:
        the message text, or the ``content`` fields of a message list joined by
        ``#MSG_SEP#``; prefixed by the joined system messages plus
        ``#SYSTEM_MSG_END#`` when any are given. ``format_msgs`` and ``images``
        do not contribute to the key.
        """
        separator = "#MSG_SEP#"
        cache_key = separator.join(m["content"] for m in msg) if isinstance(msg, list) else msg
        if system_msgs:
            cache_key = separator.join(system_msgs) + "#SYSTEM_MSG_END#" + cache_key
        return await self._mock_rsp(
            cache_key, self.original_aask, msg, system_msgs, format_msgs, images, timeout, stream
        )

    async def aask_batch(self, msgs: list, timeout=3) -> str:
        """Cached ``aask_batch``; the key joins each message's text with ``#MSG_SEP#``.

        Items may be plain strings or objects with a ``content`` attribute.
        """
        cache_key = "#MSG_SEP#".join(m if isinstance(m, str) else m.content for m in msgs)
        return await self._mock_rsp(cache_key, self.original_aask_batch, msgs, timeout)

    async def aask_code(self, messages: Union[str, Message, list[dict]], **kwargs) -> dict:
        """Cached ``aask_code``; the key is the JSON dump of ``format_msg(messages)``.

        ``kwargs`` (e.g. ``tools``) do not contribute to the key.
        """
        cache_key = json.dumps(self.format_msg(messages), ensure_ascii=False)
        return await self._mock_rsp(cache_key, self.original_aask_code, messages, **kwargs)

    async def _mock_rsp(self, msg_key, ask_func, *args, **kwargs):
        """Return the cached response for ``msg_key``, or call ``ask_func`` when allowed.

        Raises:
            ValueError: ``msg_key`` is not cached and real API calls are disabled.
        """
        if msg_key in self.rsp_cache:
            logger.warning("Use response cache")
            answer = self.rsp_cache[msg_key]
        else:
            if not self.allow_open_api_call:
                raise ValueError(
                    "In current test setting, api call is not allowed, you should properly mock your tests, "
                    "or add expected api response in tests/data/rsp_cache.json. "
                )
            # Cache miss with API calls permitted: run the unmocked implementation.
            answer = await ask_func(*args, **kwargs)
        self.rsp_candidates.append({msg_key: answer})
        return answer