# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from __future__ import annotations import os import uuid from typing import Any, Dict, List, Optional, Type from erniebot_agent.messages import AIMessage, HumanMessage, Message from erniebot_agent.tools.base import Tool from erniebot_agent.tools.schema import ToolParameterView from erniebot_agent.utils.common import download_file, get_cache_dir from pydantic import Field import erniebot class ImageGenerationInputView(ToolParameterView): prompt: str = Field(description="描述图像内容、风格的文本。例如:生成一张月亮的照片,月亮很圆。") width: int = Field(description="生成图片的宽度") height: int = Field(description="生成图片的高度") image_num: int = Field(description="生成图片的数量") class ImageGenerationOutputView(ToolParameterView): image_path: str = Field(description="图片在本地机器上的保存路径") class ImageGenerationTool(Tool): description: str = "AI作图、生成图片、画图的工具" input_type: Type[ToolParameterView] = ImageGenerationInputView ouptut_type: Type[ToolParameterView] = ImageGenerationOutputView def __init__( self, yinian_access_token: Optional[str] = None, yinian_ak: Optional[str] = None, yinian_sk: Optional[str] = None, ) -> None: self.config: Dict[str, Optional[Any]] if yinian_access_token is not None: self.config = {"api_type": "yinian", "access_token": yinian_access_token} elif yinian_ak is not None and yinian_sk is not None: self.config = {"api_type": "yinian", "ak": yinian_ak, "sk": yinian_sk} else: raise ValueError("Please set the yinian_access_token, or set yinian_ak and yinian_sk") async def __call__( self, prompt: str, width: int = 512, height: int = 512, image_num: int = 1, ) -> Dict[str, List[str]]: response = erniebot.Image.create( model="ernie-vilg-v2", prompt=prompt, width=width, height=height, image_num=image_num, _config_=self.config, ) image_path = [] cache_dir = get_cache_dir() for item in response["data"]["sub_task_result_list"]: image_url = item["final_image_list"][0]["img_url"] save_path = os.path.join(cache_dir, f"img_{uuid.uuid1()}.png") download_file(image_url, save_path) image_path.append(save_path) return {"image_path": image_path} @property def examples(self) -> List[Message]: return [ HumanMessage("画一张小狗的图片,图像高度512,图像宽度512"), AIMessage( "", function_call={ "name": "ImageGenerationTool", "thoughts": "用户需要我生成一张小狗的图片,图像高度为512,宽度为512。我可以使用ImageGenerationTool工具来满足用户的需求。", "arguments": '{"prompt":"画一张小狗的图片,图像高度512,图像宽度512",' '"width":512,"height":512,"image_num":1}', }, ), HumanMessage("生成两张天空的图片"), AIMessage( "", function_call={ "name": self.tool_name, "thoughts": "用户想要生成两张天空的图片,我需要调用ImageGenerationTool工具的call接口," "并设置prompt为'生成两张天空的图片',width和height可以默认为512,image_num默认为2。", "arguments": '{"prompt":"生成两张天空的图片","width":512,"height":512,"image_num":2}', }, ), HumanMessage("使用AI作图工具,生成1张小猫的图片,高度和高度是1024"), AIMessage( "", function_call={ "name": self.tool_name, "thoughts": "用户需要生成一张小猫的图片,高度和宽度都是1024。我可以使用ImageGenerationTool工具来满足用户的需求。", "arguments": '{"prompt":"生成一张小猫的照片。","width":1024,"height":1024,"image_num":1}', }, ), ]