File size: 4,991 Bytes
569cdb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import os
import uuid
from typing import Any, Dict, List, Optional, Type

from erniebot_agent.messages import AIMessage, HumanMessage, Message
from erniebot_agent.tools.base import Tool
from erniebot_agent.tools.schema import ToolParameterView
from erniebot_agent.utils.common import download_file, get_cache_dir
from pydantic import Field

import erniebot


class ImageGenerationInputView(ToolParameterView):
    # Input parameters accepted by the image-generation tool. Comments are used
    # instead of a class docstring so that nothing extra can end up in the
    # schema generated from this model.

    # Text describing the content and style of the image to generate.
    prompt: str = Field(
        description="描述图像内容、风格的文本。例如:生成一张月亮的照片,月亮很圆。",
    )
    # Width of the generated image.
    width: int = Field(
        description="生成图片的宽度",
    )
    # Height of the generated image.
    height: int = Field(
        description="生成图片的高度",
    )
    # Number of images to generate.
    image_num: int = Field(
        description="生成图片的数量",
    )


class ImageGenerationOutputView(ToolParameterView):
    # Output of the image-generation tool: where the generated image was saved
    # on the local machine.
    # NOTE(review): declared as `str`, while `ImageGenerationTool.__call__`
    # returns a list of paths under the "image_path" key -- confirm which shape
    # consumers of this schema expect.
    image_path: str = Field(
        description="图片在本地机器上的保存路径",
    )


class ImageGenerationTool(Tool):
    """Tool that generates images from a text prompt and saves them locally.

    Each call sends the prompt to ``erniebot.Image.create`` (model
    ``ernie-vilg-v2``), downloads the first image of every returned sub-task
    into the cache directory, and returns the local file paths.
    """

    description: str = "AI作图、生成图片、画图的工具"
    input_type: Type[ToolParameterView] = ImageGenerationInputView
    # NOTE(review): the spelling `ouptut_type` is kept on purpose; presumably it
    # matches the attribute name the `Tool` base class reads -- confirm before
    # renaming it.
    ouptut_type: Type[ToolParameterView] = ImageGenerationOutputView

    def __init__(
        self,
        yinian_access_token: Optional[str] = None,
        yinian_ak: Optional[str] = None,
        yinian_sk: Optional[str] = None,
    ) -> None:
        """Build the request configuration from the given credentials.

        Args:
            yinian_access_token: Access token. When provided it is used and the
                ak/sk pair is ignored.
            yinian_ak: API key; only used together with ``yinian_sk``.
            yinian_sk: Secret key; only used together with ``yinian_ak``.

        Raises:
            ValueError: If neither an access token nor a complete ak/sk pair
                is provided.
        """
        self.config: Dict[str, Optional[Any]]
        if yinian_access_token is not None:
            self.config = {"api_type": "yinian", "access_token": yinian_access_token}
        elif yinian_ak is not None and yinian_sk is not None:
            self.config = {"api_type": "yinian", "ak": yinian_ak, "sk": yinian_sk}
        else:
            raise ValueError("Please set the yinian_access_token, or set yinian_ak and yinian_sk")

    async def __call__(
        self,
        prompt: str,
        width: int = 512,
        height: int = 512,
        image_num: int = 1,
    ) -> Dict[str, List[str]]:
        """Generate images for ``prompt`` and download them to the cache directory.

        Args:
            prompt: Text describing the image to generate.
            width: Requested image width.
            height: Requested image height.
            image_num: Number of images requested.

        Returns:
            A dict with a single key ``"image_path"`` whose value is the list
            of local paths of the downloaded images, one per sub-task result.
        """
        # NOTE(review): `erniebot.Image.create` and `download_file` are called
        # without `await`, so they appear to run synchronously inside this
        # coroutine -- confirm that blocking here is acceptable.
        response = erniebot.Image.create(
            model="ernie-vilg-v2",
            prompt=prompt,
            width=width,
            height=height,
            image_num=image_num,
            _config_=self.config,
        )

        image_path = []
        cache_dir = get_cache_dir()
        for item in response["data"]["sub_task_result_list"]:
            # Only the first entry of each sub-task's image list is kept.
            image_url = item["final_image_list"][0]["img_url"]
            # A fresh UUID per file keeps names from colliding in the shared cache dir.
            save_path = os.path.join(cache_dir, f"img_{uuid.uuid1()}.png")
            download_file(image_url, save_path)
            image_path.append(save_path)
        return {"image_path": image_path}

    @property
    def examples(self) -> List[Message]:
        """Return example dialogues pairing a user request with a call to this tool.

        Every example references the tool through ``self.tool_name`` so the
        examples all carry the same name.
        """
        return [
            HumanMessage("画一张小狗的图片,图像高度512,图像宽度512"),
            AIMessage(
                "",
                function_call={
                    # Use the tool's own name rather than a hard-coded literal,
                    # consistent with the examples below.
                    "name": self.tool_name,
                    "thoughts": "用户需要我生成一张小狗的图片,图像高度为512,宽度为512。我可以使用ImageGenerationTool工具来满足用户的需求。",
                    "arguments": '{"prompt":"画一张小狗的图片,图像高度512,图像宽度512",'
                    '"width":512,"height":512,"image_num":1}',
                },
            ),
            HumanMessage("生成两张天空的图片"),
            AIMessage(
                "",
                function_call={
                    "name": self.tool_name,
                    "thoughts": "用户想要生成两张天空的图片,我需要调用ImageGenerationTool工具的call接口,"
                    "并设置prompt为'生成两张天空的图片',width和height可以默认为512,image_num默认为2。",
                    "arguments": '{"prompt":"生成两张天空的图片","width":512,"height":512,"image_num":2}',
                },
            ),
            # The request states height and width (the original text repeated
            # "height" twice), matching the thoughts and arguments below.
            HumanMessage("使用AI作图工具,生成1张小猫的图片,高度和宽度是1024"),
            AIMessage(
                "",
                function_call={
                    "name": self.tool_name,
                    "thoughts": "用户需要生成一张小猫的图片,高度和宽度都是1024。我可以使用ImageGenerationTool工具来满足用户的需求。",
                    "arguments": '{"prompt":"生成一张小猫的照片。","width":1024,"height":1024,"image_num":1}',
                },
            ),
        ]