markqiu's picture
百度文心一言的例子
569cdb0
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings
from typing import List, Optional, Union

from erniebot_agent.messages import AIMessage, Message, SystemMessage
class MessageManager:
    """
    Messages Manager.

    Keeps regular messages in insertion order in ``messages`` and at most one
    system message separately in ``system_message``.
    """

    def __init__(self) -> None:
        self.messages: List[Message] = []
        self._system_message: Union[SystemMessage, None] = None

    @property
    def system_message(self) -> Optional[SystemMessage]:
        """
        The single system message held by this manager.

        return: SystemMessage, or None if no system message has been set.
        """
        return self._system_message

    @system_message.setter
    def system_message(self, message: SystemMessage) -> None:
        """Set the system message, warning when an existing one is replaced."""
        if self._system_message is not None:
            # Actually emit the warning; merely constructing a Warning object
            # (as opposed to warning or raising it) reports nothing.
            warnings.warn(
                "system message has been set, the previous one will be replaced",
                stacklevel=2,
            )
        self._system_message = message

    def add_messages(self, messages: List[Message]) -> None:
        """
        Add several messages in order.

        Each message goes through ``add_message`` so that ``SystemMessage``
        instances are stored as the system message rather than being mixed
        into the regular history.
        """
        for message in messages:
            self.add_message(message)

    def add_message(self, message: Message) -> None:
        """Store a ``SystemMessage`` as the system message; append anything else."""
        if isinstance(message, SystemMessage):
            self.system_message = message
        else:
            self.messages.append(message)

    def pop_message(self) -> Message:
        """
        Remove and return the oldest regular message.

        Raises:
            IndexError: if there are no regular messages.
        """
        return self.messages.pop(0)

    def clear_messages(self) -> None:
        """Drop all regular messages. The system message is not affected."""
        # Rebind rather than clear in place, so lists previously handed out by
        # retrieve_messages() are left as they were.
        self.messages = []

    def update_last_message_token_count(self, token_count: int) -> None:
        """
        Record the token count of the most recent regular message.

        A ``token_count`` of 0 makes it fall back to the character length of
        that message's content. Does nothing when there are no messages yet
        (e.g. an AI reply arriving before any query has been stored).
        """
        if not self.messages:
            return
        last_message = self.messages[-1]
        if token_count == 0:
            last_message.token_count = len(last_message.content)
        else:
            last_message.token_count = token_count

    def retrieve_messages(self) -> List[Message]:
        """Return the live list of regular messages (not a copy)."""
        return self.messages
class Memory:
    """The base class of memory; delegates storage to a ``MessageManager``."""

    def __init__(self) -> None:
        self.msg_manager = MessageManager()

    def add_messages(self, messages: List[Message]) -> None:
        """Add messages one at a time, in order, via ``add_message``."""
        for message in messages:
            self.add_message(message)

    def add_message(self, message: Message) -> None:
        """
        Add a single message to the memory.

        For an ``AIMessage``, its ``query_tokens_count`` is first written to the
        current last message before the AI message itself is stored
        (presumably that last message is the query this reply answers --
        TODO confirm against callers).
        """
        if isinstance(message, AIMessage):
            self.msg_manager.update_last_message_token_count(message.query_tokens_count)
        self.msg_manager.add_message(message)

    def get_messages(self) -> List[Message]:
        """Return the message list held by the underlying manager."""
        return self.msg_manager.retrieve_messages()

    def get_system_message(self) -> Optional[SystemMessage]:
        """Return the system message, or ``None`` if none has been set."""
        return self.msg_manager.system_message

    def clear_chat_history(self) -> None:
        """Clear the regular message history; the system message is kept."""
        self.msg_manager.clear_messages()
class WholeMemory(Memory):
    """Memory that includes all of the messages."""

    def __init__(self) -> None:
        """Initialize the memory through ``Memory.__init__``."""
        super().__init__()