Superlang commited on
Commit
a915193
·
1 Parent(s): 3b18a43

first init

Browse files
Files changed (6) hide show
  1. .gitignore +2 -0
  2. README.md +3 -0
  3. app.py +217 -0
  4. avatar.png +0 -0
  5. demo.png +0 -0
  6. requirements.txt +0 -0
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ /.idea/
2
+ /.img/
README.md CHANGED
@@ -11,3 +11,6 @@ license: apache-2.0
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
14
+
15
+ ![image](demo.png)
16
+
app.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import io
3
+ import json
4
+ import os
5
+ import os.path
6
+ import re
7
+ import time
8
+ from abc import ABC
9
+ from typing import Any
10
+ from uuid import uuid4
11
+
12
+ import gradio as gr
13
+ import requests
14
+ from PIL import Image
15
+ from langchain.agents import initialize_agent
16
+ from langchain.chat_models import AzureChatOpenAI
17
+ from langchain.memory import ConversationBufferWindowMemory
18
+ from langchain.tools import BaseTool
19
+
20
+ SAVE_FOLDER = "./img"
21
+ SDXL_API_KEY = "XXX"
22
+ SDXL_API_SECRET = "XXXX"
23
+ AZURE_END_POINT = "https://aimodelgpt.openai.azure.com"
24
+ AZURE_OPEN_KEY = "XXXX"
25
+
26
+
27
+ class SdxlImage(BaseTool, ABC):
28
+ name = "AI SDXL Image Generator"
29
+
30
+ description = 'use this tool when you need to generate images by using SDXL model, To use the tool, you must ' \
31
+ 'provide prompt parameters prompt, prompt is the description and number of the image, for example, ' \
32
+ 'if you want to generate two images about a cute cat, set prompt = a cute cat[SEP]2'
33
+
34
+ NEGATIVE_PROMPT = "worst quality, low quality, normal quality, lowres, watermark, monochrome, grayscale, ugly, " \
35
+ "blurry, Tan skin, dark skin, black skin, skin spots, skin blemishes, age spot, glans, " \
36
+ "disabled, distorted, bad anatomy, morbid, malformation, amputation, bad proportions, twins, " \
37
+ "missing body, fused body, extra head, poorly drawn face, bad eyes, deformed eye, unclear eyes, " \
38
+ "cross-eyed, long neck, malformed limbs, extra limbs, extra arms, missing arms, bad tongue, " \
39
+ "strange fingers, mutated hands, missing hands, poorly drawn hands, extra hands, fused hands, " \
40
+ "connected hand, bad hands, wrong fingers, missing fingers, extra fingers, 4 fingers, " \
41
+ "3 fingers, deformed hands, extra legs, bad legs, many legs, more than two legs, bad feet, " \
42
+ "wrong feet, extra feets,"
43
+
44
+ api_key: str
45
+ api_secret: str
46
+
47
+ # def __init__(self, api_key, api_secret):
48
+ # self.api_key = api_key
49
+ # self.api_secret = api_secret
50
+
51
+ def _run(
52
+ self,
53
+ prompt,
54
+ **kwargs: Any,
55
+ ) -> Any:
56
+ print(f"execute SDXL Image Tool {prompt}")
57
+ split_items = prompt.split("[SEP]")
58
+ number = 1
59
+ if len(split_items) > 1:
60
+ prompt, number = split_items
61
+ return self.generate_image(query=prompt, number=int(number))
62
+
63
+ def get_access_token(self):
64
+ """
65
+ 使用 API Key,Secret Key 获取access_token,替换下列示例中的应用API Key、应用Secret Key
66
+ """
67
+ url = f'https://aip.baidubce.com/oauth/2.0/token?grant_type=client_credentials&client_id={self.api_key}&client_secret={self.api_secret}'
68
+
69
+ payload = json.dumps("")
70
+ headers = {
71
+ 'Content-Type': 'application/json',
72
+ 'Accept': 'application/json'
73
+ }
74
+
75
+ response = requests.request("POST", url, headers=headers, data=payload)
76
+ return response.json().get("access_token")
77
+
78
+ def save_image(self, base64_string):
79
+ file_path = _id = str(uuid4()) + ".png"
80
+ image_data = base64.b64decode(base64_string)
81
+ image = Image.open(io.BytesIO(image_data))
82
+ if not os.path.exists(SAVE_FOLDER):
83
+ os.mkdir(SAVE_FOLDER)
84
+ image.save(os.path.join(SAVE_FOLDER, file_path))
85
+ return file_path
86
+
87
+ def generate_image(self, query: str, number: int = 1):
88
+ token = self.get_access_token()
89
+ url = "https://aip.baidubce.com/rpc/2.0/ai_custom/v1/wenxinworkshop/text2image/sd_xl?access_token=" + token
90
+
91
+ payload = json.dumps({
92
+ "prompt": query,
93
+ "negative_prompt": self.NEGATIVE_PROMPT,
94
+ "size": "768x1024",
95
+ "steps": 25,
96
+ "n": number,
97
+ "sampler_index": "DPM++ SDE Karras"
98
+ })
99
+ headers = {
100
+ 'Content-Type': 'application/json'
101
+ }
102
+ response = requests.request("POST", url, headers=headers, data=payload)
103
+
104
+ try:
105
+ if response and response.text:
106
+ data = json.loads(response.text)['data']
107
+ if data:
108
+ filenames = ",".join([self.save_image(sub_data['b64_image']) for sub_data in data])
109
+ return f"generate total {number} of the {query}, output is all the files {filenames}"
110
+ except Exception as err:
111
+ print(err)
112
+
113
+ return "failed to call tool, got error message"
114
+
115
+
116
+ class AgentBot:
117
+ def __init__(self):
118
+ chat_llm = AzureChatOpenAI(
119
+ azure_endpoint=AZURE_END_POINT,
120
+ openai_api_key=AZURE_OPEN_KEY,
121
+ deployment_name="gpt-35-turbo",
122
+ openai_api_version="2023-10-01-preview",
123
+ temperature=0.0
124
+ )
125
+ # initialize conversational memory
126
+ conversational_memory = ConversationBufferWindowMemory(
127
+ memory_key='chat_history',
128
+ k=5,
129
+ return_messages=True
130
+ )
131
+
132
+ tools = [SdxlImage(api_key=SDXL_API_KEY, api_secret=SDXL_API_SECRET)]
133
+
134
+ # initialize agent with tools
135
+ self.agent = initialize_agent(
136
+ agent='chat-conversational-react-description',
137
+ tools=tools,
138
+ llm=chat_llm,
139
+ verbose=True,
140
+ max_iterations=3,
141
+ early_stopping_method='generate',
142
+ memory=conversational_memory
143
+ )
144
+
145
+ def run(self, txt) -> str:
146
+ result = self.agent(txt)
147
+ return result["output"]
148
+
149
+ def clear(self):
150
+ self.agent.memory.clear()
151
+
152
+
153
+ bot = AgentBot()
154
+
155
+ block_css = """#col_container {width: 1000px; margin-left: auto; margin-right: auto;}
156
+ #chatbot {height: 520px; overflow: auto;}"""
157
+
158
+ with gr.Blocks(css=block_css) as demo:
159
+ gr.Markdown("<h3><center>ChatGPT LangChain</center></h3>")
160
+ gr.Markdown(
161
+ """
162
+ This LangChain GPT can generate SD-XL Image
163
+ """
164
+ )
165
+
166
+ with gr.Row() as input_raw:
167
+ with gr.Column(elem_id="col_container"):
168
+ chatbot = gr.Chatbot([],
169
+ elem_id="chatbot",
170
+ label="ChatBot LangChain for AIGC",
171
+ bubble_full_width=False,
172
+ avatar_images=(None, (os.path.join(os.path.dirname(__file__), "avatar.png"))),
173
+ )
174
+
175
+ msg = gr.Textbox()
176
+
177
+ with gr.Row():
178
+ with gr.Column(scale=0.10, min_width=0):
179
+ run = gr.Button("🏃‍♂️Run")
180
+ with gr.Column(scale=0.10, min_width=0):
181
+ clear = gr.Button("🔄Clear️")
182
+
183
+
184
+ def respond(message, chat_history):
185
+ # bot_message = random.choice(["How are you?", "I love you", "I'm very hungry"])
186
+ bot_message = bot.run(message)
187
+ regx = r'\b[\w-]+\.png'
188
+ match_image = re.findall(regx, bot_message)
189
+ chat_history.append((message, bot_message))
190
+ if match_image:
191
+ for image in match_image:
192
+ image_path = os.path.join(SAVE_FOLDER, image)
193
+ chat_history.append(
194
+ (None, (image_path,)),
195
+ )
196
+ time.sleep(2)
197
+ return "", chat_history
198
+
199
+
200
+ def clearMessage():
201
+ # clear agent memory
202
+ bot.clear()
203
+
204
+ # execute action
205
+ msg.submit(respond, [msg, chatbot], [msg, chatbot])
206
+ run.click(respond, [msg, chatbot], [msg, chatbot])
207
+ clear.click(clearMessage)
208
+ clear.click(lambda: [], None, chatbot)
209
+
210
+ gr.Examples(
211
+ examples=["generate a image about a boy reading books using SDXL",
212
+ "generate two images about a gril in the classroom using SDXL",
213
+ ],
214
+ inputs=msg
215
+ )
216
+
217
+ demo.queue(concurrency_count=10).launch()
avatar.png ADDED
demo.png ADDED
requirements.txt ADDED
Binary file (170 Bytes). View file