# File size: 4,019 Bytes
# 569cdb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Optional, TypedDict

import erniebot.utils.token_helper as token_helper


class Message:
    """The base class of a message.

    Args:
        role: The role of the message sender, e.g. ``"user"`` or ``"assistant"``.
        content: The text content of the message.
        token_count: The number of tokens of the message, if already known.
            If omitted, it can be assigned later through the ``token_count``
            property, but only once.
    """

    def __init__(self, role: str, content: str, token_count: Optional[int] = None):
        self.role = role
        self.content = content
        self._token_count = token_count
        # Attribute names exported by ``to_dict`` and shown by ``str``/``repr``;
        # subclasses may replace or extend this list.
        self._param_names = ["role", "content"]

    @property
    def token_count(self) -> int:
        """Get the number of tokens of the message.

        Raises:
            AttributeError: If the token count has not been set.
        """
        if self._token_count is None:
            raise AttributeError("The token count of the message has not been set.")
        return self._token_count

    @token_count.setter
    def token_count(self, token_count: int):
        """Set the number of tokens of the message.

        Raises:
            AttributeError: If the token count has already been set.
        """
        if self._token_count is not None:
            raise AttributeError("The token count of the message can only be set once.")
        self._token_count = token_count

    def to_dict(self) -> Dict[str, Any]:
        """Return the exported attributes as a dict keyed by attribute name.

        Values are copied as-is, including ``None`` and empty strings. They are
        not necessarily strings, because subclasses may export non-string
        attributes (for example a dict).
        """
        return {name: getattr(self, name) for name in self._param_names}

    def __str__(self) -> str:
        return f"<{self._get_attrs_str()}>"

    def __repr__(self) -> str:
        return f"<{self.__class__.__name__} {self._get_attrs_str()}>"

    def _get_attrs_str(self) -> str:
        """Format the non-empty exported attributes, plus the token count if set."""
        parts: List[str] = []
        for name in self._param_names:
            value = getattr(self, name)
            # Skip unset/empty attributes to keep the textual form concise.
            if value is not None and value != "":
                parts.append(f"{name}: {value!r}")
        if self._token_count is not None:
            parts.append(f"token_count: {self._token_count}")
        return ", ".join(parts)


class SystemMessage(Message):
    """A message (role ``"system"``) that sets system information for the chat.

    Args:
        content: The system information text.
    """

    def __init__(self, content: str):
        # NOTE(review): the recorded "token count" is the character length of
        # ``content``, not a tokenizer estimate; confirm this is intentional
        # (e.g. to enforce a character-based length limit).
        char_count = len(content)
        super().__init__(content=content, role="system", token_count=char_count)


class HumanMessage(Message):
    """A message written by a human user (role ``"user"``).

    Args:
        content: The text written by the user.
    """

    def __init__(self, content: str):
        # The token count is left unset here; it may be assigned later
        # (once) through the ``token_count`` property.
        super().__init__(content=content, role="user")


class FunctionCall(TypedDict, total=True):
    """A function call requested by the assistant; every key is required."""

    name: str  # name of the function to call
    thoughts: str  # accompanying reasoning text (presumably from the model — confirm)
    arguments: str  # call arguments as a string (presumably JSON-encoded — confirm)


class TokenUsage(TypedDict, total=True):
    """Token usage counts reported by the LLM; every key is required."""

    prompt_tokens: int  # number of tokens in the prompt/input
    completion_tokens: int  # number of tokens in the generated completion


class AIMessage(Message):
    """A message produced by the assistant (role ``"assistant"``).

    Args:
        content: The text generated by the assistant.
        function_call: The function call requested by the assistant, or
            ``None`` if there is none.
        token_usage: Token usage reported by the LLM. When omitted, the prompt
            token count is taken as 0 and the completion token count is
            approximated locally from ``content``.

    Attributes:
        function_call: The ``function_call`` argument, stored unchanged.
        query_tokens_count: The number of prompt tokens (0 when unknown).
    """

    def __init__(
        self,
        content: str,
        function_call: Optional[FunctionCall],
        token_usage: Optional[TokenUsage] = None,
    ):
        if token_usage is not None:
            prompt_tokens, completion_tokens = self._parse_token_count(token_usage)
        else:
            # No usage info from the LLM: estimate the completion size locally.
            prompt_tokens, completion_tokens = 0, token_helper.approx_num_tokens(content)
        # The message's own token count is the completion (output) size.
        super().__init__(role="assistant", content=content, token_count=completion_tokens)
        self.function_call = function_call
        self.query_tokens_count = prompt_tokens
        # Also export ``function_call`` via ``to_dict`` and ``str``/``repr``.
        self._param_names = ["role", "content", "function_call"]

    def _parse_token_count(self, token_usage: TokenUsage):
        """Return ``(prompt_tokens, completion_tokens)`` from the LLM's usage info."""
        prompt = token_usage["prompt_tokens"]
        completion = token_usage["completion_tokens"]
        return prompt, completion


class FunctionMessage(Message):
    """A message (role ``"function"``) carrying the result of a function call.

    Args:
        name: The name of the function that was called.
        content: The result of the function call, as text.
    """

    def __init__(self, name: str, content: str):
        super().__init__(content=content, role="function")
        self.name = name
        # Export ``name`` (between ``role`` and ``content``) via ``to_dict``
        # and ``str``/``repr``.
        self._param_names = ["role", "name", "content"]


class AIMessageChunk(AIMessage):
    """A message chunk from an assistant."""