File size: 4,019 Bytes
569cdb0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, List, Optional, TypedDict
import erniebot.utils.token_helper as token_helper
class Message:
    """Base container for a single conversation message.

    A message carries a sender ``role``, its text ``content`` and an optional
    token count. ``_param_names`` lists the attribute names that ``to_dict``
    serializes and that ``__str__``/``__repr__`` render.
    """

    def __init__(self, role: str, content: str, token_count: Optional[int] = None):
        """Store the role, the content and, if given, the token count."""
        self._param_names = ["role", "content"]
        self._token_count = token_count
        self.role = role
        self.content = content

    @property
    def token_count(self):
        """Return the token count of the message.

        Raises:
            AttributeError: If no token count has been set yet.
        """
        count = self._token_count
        if count is None:
            raise AttributeError("The token count of the message has not been set.")
        return count

    @token_count.setter
    def token_count(self, token_count: int):
        """Set the token count; it may be assigned only once.

        Raises:
            AttributeError: If a token count is already stored.
        """
        if self._token_count is not None:
            raise AttributeError("The token count of the message can only be set once.")
        self._token_count = token_count

    def to_dict(self) -> Dict[str, str]:
        """Return a dict mapping every name in ``_param_names`` to its value."""
        return {name: getattr(self, name) for name in self._param_names}

    def __str__(self) -> str:
        return "<" + self._get_attrs_str() + ">"

    def __repr__(self):
        return "<{} {}>".format(self.__class__.__name__, self._get_attrs_str())

    def _get_attrs_str(self) -> str:
        """Render the attributes as a comma-separated ``name: value`` string.

        Attributes whose value is ``None`` or an empty string are skipped; the
        token count is appended last, and only when it has been set.
        """
        named_values = ((name, getattr(self, name)) for name in self._param_names)
        parts: List[str] = [
            f"{name}: {value!r}"
            for name, value in named_values
            if value is not None and value != ""
        ]
        if self._token_count is not None:
            parts.append(f"token_count: {self._token_count}")
        return ", ".join(parts)
class SystemMessage(Message):
    """A message written by a human that sets system information.

    The role is always ``"system"``. The token count is fixed at construction
    time to the character length of ``content``.
    """

    def __init__(self, content: str):
        char_count = len(content)
        super().__init__("system", content, char_count)
class HumanMessage(Message):
    """A message written by a human user.

    The role is always ``"user"``; the token count is left unset.
    """

    def __init__(self, content: str):
        super().__init__("user", content)
class FunctionCall(TypedDict):
    """Typed dict describing a function call attached to an assistant message."""

    # Presumably the name of the function to invoke -- confirm with producer.
    name: str
    # Presumably the model's reasoning for making the call -- confirm.
    thoughts: str
    # Call arguments kept as a single string (encoding not shown here).
    arguments: str
class TokenUsage(TypedDict):
    """Typed dict holding the prompt and completion token counts of one reply."""

    # Token count attributed to the prompt.
    prompt_tokens: int
    # Token count attributed to the generated completion.
    completion_tokens: int
class AIMessage(Message):
    """A message produced by the assistant.

    Besides the role and content, it carries an optional ``function_call``
    (included in the serialized parameters) and ``query_tokens_count``, the
    number of prompt tokens.
    """

    def __init__(
        self,
        content: str,
        function_call: Optional[FunctionCall],
        token_usage: Optional[TokenUsage] = None,
    ):
        """Build an assistant message.

        Args:
            content: The text of the reply.
            function_call: The function call attached to the reply, or None.
            token_usage: Token usage for the reply. When omitted, the prompt
                token count is 0 and the completion token count is
                approximated from ``content``.
        """
        if token_usage is not None:
            prompt_tokens, completion_tokens = self._parse_token_count(token_usage)
        else:
            # No usage supplied: fall back to an approximation of the content.
            prompt_tokens, completion_tokens = 0, token_helper.approx_num_tokens(content)
        super().__init__(role="assistant", content=content, token_count=completion_tokens)
        self._param_names = ["role", "content", "function_call"]
        self.query_tokens_count = prompt_tokens
        self.function_call = function_call

    def _parse_token_count(self, token_usage: TokenUsage):
        """Return ``(prompt_tokens, completion_tokens)`` taken from ``token_usage``."""
        prompt_tokens = token_usage["prompt_tokens"]
        completion_tokens = token_usage["completion_tokens"]
        return prompt_tokens, completion_tokens
class FunctionMessage(Message):
    """A message from a human that carries the result of a function call.

    The role is always ``"function"``. ``name`` is stored on the instance and
    listed between ``role`` and ``content`` in the serialized parameters.
    """

    def __init__(self, name: str, content: str):
        super().__init__("function", content)
        self._param_names = ["role", "name", "content"]
        self.name = name
class AIMessageChunk(AIMessage):
    """A chunk of a message produced by the assistant."""
|