|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from erniebot_agent.memory import Memory |
|
from erniebot_agent.messages import AIMessage, Message |
|
|
|
|
|
class LimitTokensMemory(Memory):
    """Memory that keeps the accumulated token count below ``max_token_limit``.

    Token counts are accumulated each time an ``AIMessage`` is added (for that
    message and the message stored right before it). Whenever the running
    total exceeds ``max_token_limit``, messages are removed through
    ``msg_manager.pop_message()`` until the total fits again.

    Args:
        max_token_limit: Maximum number of tokens kept in memory, or ``None``
            to disable pruning. Defaults to 6000.
    """

    def __init__(self, max_token_limit=6000):
        super().__init__()
        self.max_token_limit = max_token_limit
        # Running sum of ``token_count`` over the messages accounted for so far.
        self.mem_token_count = 0

        # NOTE: ``assert`` is skipped when Python runs with ``-O``.
        assert max_token_limit is None or max_token_limit > 0, (
            "max_token_limit should be None or positive integer, "
            "but got {max_token_limit}".format(max_token_limit=max_token_limit)
        )

    def add_message(self, message: Message):
        """Store ``message`` and prune the memory when it is an ``AIMessage``.

        NOTE(review): pruning only runs on AI messages; presumably because the
        token count of the preceding (human) message is only known once the
        model has replied -- confirm against ``Message.token_count``.
        """
        super().add_message(message)
        if isinstance(message, AIMessage):
            self.prune_message()

    def prune_message(self):
        """Account for the newest messages and drop messages over the limit.

        Adds the ``token_count`` of the last message and, if present, of the
        message before it to ``mem_token_count``. If ``max_token_limit`` is
        set, messages are then popped via ``msg_manager.pop_message()``
        (presumably oldest first -- verify in the message manager) until the
        running total no longer exceeds the limit.

        Raises:
            RuntimeError: If pruning leaves the memory without any message.
        """
        messages = self.msg_manager.messages
        # The newest message (the AI reply) plus the one stored right before
        # it, assumed to be the human message it answers. Slicing avoids an
        # IndexError when only a single message is stored.
        for message in messages[-2:]:
            self.mem_token_count += message.token_count

        if self.max_token_limit is None:
            return

        while self.mem_token_count > self.max_token_limit:
            deleted_message = self.msg_manager.pop_message()
            self.mem_token_count -= deleted_message.token_count
            # Stop as soon as the memory is emptied instead of popping from an
            # empty memory on the next iteration.
            if len(self.get_messages()) == 0:
                raise RuntimeError(
                    "The memory is now empty. It indicates that {} takes up {} "
                    "tokens and exceeded {} tokens.".format(
                        deleted_message,
                        deleted_message.token_count,
                        self.max_token_limit,
                    )
                )
|
|