yizhangliu commited on
Commit
b788590
·
1 Parent(s): 655b569

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +151 -42
app.py CHANGED
@@ -1,7 +1,8 @@
1
  from transformers import pipeline
2
  import gradio as gr
3
  import random
4
- import paddlehub as hub
 
5
  import torch
6
  from transformers import AutoModelForCausalLM, AutoTokenizer
7
  from loguru import logger
@@ -23,23 +24,6 @@ def getTextTrans(text, source='zh', target='en'):
23
  except Exception as e:
24
  return text
25
 
26
- extend_prompt_pipe = pipeline('text-generation', model='yizhangliu/prompt-extend', max_length=77, pad_token_id=0)
27
-
28
- def load_prompter():
29
- prompter_model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist")
30
- tokenizer = AutoTokenizer.from_pretrained("gpt2")
31
- tokenizer.pad_token = tokenizer.eos_token
32
- tokenizer.padding_side = "left"
33
- return prompter_model, tokenizer
34
- prompter_model, prompter_tokenizer = load_prompter()
35
- def extend_prompt_microsoft(in_text):
36
- input_ids = prompter_tokenizer(in_text.strip()+" Rephrase:", return_tensors="pt").input_ids
37
- eos_id = prompter_tokenizer.eos_token_id
38
- outputs = prompter_model.generate(input_ids, do_sample=False, max_new_tokens=75, num_beams=8, num_return_sequences=8, eos_token_id=eos_id, pad_token_id=eos_id, length_penalty=-1.0)
39
- output_texts = prompter_tokenizer.batch_decode(outputs, skip_special_tokens=True)
40
- res = output_texts[0].replace(in_text+" Rephrase:", "").strip()
41
- return res
42
-
43
  space_ids = {
44
  "spaces/stabilityai/stable-diffusion": "SD 2.1",
45
  "spaces/runwayml/stable-diffusion-v1-5": "SD 1.5",
@@ -50,9 +34,56 @@ space_ids = {
50
  tab_actions = []
51
  tab_titles = []
52
 
 
 
 
 
53
  thanks_info = "Thanks: "
54
- thanks_info += "[<a style='display:inline-block' href='https://huggingface.co/spaces/daspartho/prompt-extend' _blank><font style='color:blue;weight:bold;'>prompt-extend</font></a>]"
55
- thanks_info += "[<a style='display:inline-block' href='https://huggingface.co/spaces/microsoft/Promptist' _blank><font style='color:blue;weight:bold;'>Promptist</font></a>]"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  for space_id in space_ids.keys():
58
  print(space_id, space_ids[space_id])
@@ -94,6 +125,8 @@ start_work = """async() => {
94
  valueSetter.call(element, value);
95
  }
96
  }
 
 
97
  var gradioEl = document.querySelector('body > gradio-app').shadowRoot;
98
  if (!gradioEl) {
99
  gradioEl = document.querySelector('body > gradio-app');
@@ -103,7 +136,6 @@ start_work = """async() => {
103
  window['gradioEl'] = gradioEl;
104
  tabitems = window['gradioEl'].querySelectorAll('.tabitem');
105
  tabitems_title = window['gradioEl'].querySelectorAll('#tab_demo')[0].children[0].children[0].children;
106
-
107
  for (var i = 0; i < tabitems.length; i++) {
108
  if (tabitems_title[i].innerText.indexOf('SD') >= 0) {
109
  tabitems[i].childNodes[0].children[0].style.display='none';
@@ -113,9 +145,11 @@ start_work = """async() => {
113
  }
114
  }
115
  } else if (tabitems_title[i].innerText.indexOf('Taiyi') >= 0) {
116
- tabitems[3].children[0].children[0].children[1].style.display='none';
117
  tabitems[i].children[0].children[0].children[0].children[0].children[1].style.display='none';
118
- }
 
 
119
  }
120
 
121
  tab_demo = window['gradioEl'].querySelectorAll('#tab_demo')[0];
@@ -124,17 +158,27 @@ start_work = """async() => {
124
  const page1 = window['gradioEl'].querySelectorAll('#page_1')[0];
125
  const page2 = window['gradioEl'].querySelectorAll('#page_2')[0];
126
 
127
- btns_1 = window['gradioEl'].querySelector('#input_col1_row2').children;
128
  btns_1_split = 100 / btns_1.length;
129
  for (var i = 0; i < btns_1.length; i++) {
130
  btns_1[i].setAttribute('style', 'min-width:0px;width:' + btns_1_split + '%;');
131
  }
132
  page1.style.display = "none";
133
- page2.style.display = "block";
 
 
 
 
 
134
  window['prevPrompt'] = '';
135
  window['doCheckPrompt'] = 0;
136
  window['checkPrompt'] = function checkPrompt() {
137
  try {
 
 
 
 
 
138
  text_value = window['gradioEl'].querySelectorAll('#prompt_work')[0].querySelectorAll('textarea')[0].value;
139
  progress_bar = window['gradioEl'].querySelectorAll('.progress-bar');
140
  if (window['doCheckPrompt'] === 0 && window['prevPrompt'] !== text_value && progress_bar.length == 0) {
@@ -176,18 +220,43 @@ start_work = """async() => {
176
  return false;
177
  }"""
178
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  def prompt_extend(prompt, PM):
180
  prompt_en = getTextTrans(prompt, source='zh', target='en')
181
  if PM == 1:
182
  extend_prompt_en = extend_prompt_pipe(prompt_en+',', num_return_sequences=1)[0]["generated_text"]
183
- else:
184
  extend_prompt_en = extend_prompt_microsoft(prompt_en)
 
 
185
 
186
  if (prompt != prompt_en):
187
- logger.info(f"extend_prompt__1_[{PM}]_")
188
  extend_prompt_out = getTextTrans(extend_prompt_en, source='en', target='zh')
189
  else:
190
- logger.info(f"extend_prompt__2_[{PM}]_")
191
  extend_prompt_out = extend_prompt_en
192
 
193
  return extend_prompt_out
@@ -200,7 +269,24 @@ def prompt_extend_2(prompt):
200
  extend_prompt_out = prompt_extend(prompt, 2)
201
  return extend_prompt_out
202
 
203
- def prompt_draw(prompt):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
204
  prompt_en = getTextTrans(prompt, source='zh', target='en')
205
  if (prompt != prompt_en):
206
  logger.info(f"draw_prompt______1__")
@@ -208,9 +294,8 @@ def prompt_draw(prompt):
208
  else:
209
  logger.info(f"draw_prompt______2__")
210
  prompt_zh = getTextTrans(prompt, source='en', target='zh')
211
-
212
  return prompt_en, prompt_zh
213
-
214
  with gr.Blocks(title='Text-to-Image') as demo:
215
  with gr.Group(elem_id="page_1", visible=True) as page_1:
216
  with gr.Box():
@@ -224,25 +309,49 @@ with gr.Blocks(title='Text-to-Image') as demo:
224
  with gr.Row(elem_id="input_col1_row1"):
225
  prompt_input0 = gr.Textbox(lines=2, label="Original prompt", visible=True)
226
  with gr.Row(elem_id="input_col1_row2"):
 
 
227
  with gr.Column(elem_id="input_col1_row2_col0"):
228
  draw_btn_0 = gr.Button(value = "Generate(original)", elem_id="draw-btn-0")
229
- with gr.Column(elem_id="input_col1_row2_col1"):
230
- extend_btn_1 = gr.Button(value = "Extend_1",elem_id="extend-btn-1")
231
- with gr.Column(elem_id="input_col1_row2_col2"):
232
- extend_btn_2 = gr.Button(value = "Extend_2",elem_id="extend-btn-2")
 
 
 
 
 
233
  with gr.Column(id="input_col2"):
234
  prompt_input1 = gr.Textbox(lines=2, label="Extend prompt", visible=True)
235
  draw_btn_1 = gr.Button(value = "Generate(extend)", elem_id="draw-btn-1")
236
- prompt_work = gr.Textbox(lines=1, label="prompt_work", elem_id="prompt_work", visible=False)
237
- prompt_work_zh = gr.Textbox(lines=1, label="prompt_work_zh", elem_id="prompt_work_zh", visible=False)
 
 
238
  with gr.Row(elem_id='tab_demo', visible=True).style(height=200):
239
  tab_demo = gr.TabbedInterface(tab_actions, tab_titles)
240
- with gr.Row():
 
 
 
 
 
 
241
  gr.HTML(f"<p>{thanks_info}</p>")
242
 
243
- extend_btn_1.click(fn=prompt_extend_1, inputs=[prompt_input0], outputs=[prompt_input1])
244
- extend_btn_2.click(fn=prompt_extend_2, inputs=[prompt_input0], outputs=[prompt_input1])
245
- draw_btn_0.click(fn=prompt_draw, inputs=[prompt_input0], outputs=[prompt_work, prompt_work_zh])
246
- draw_btn_1.click(fn=prompt_draw, inputs=[prompt_input1], outputs=[prompt_work, prompt_work_zh])
 
 
 
 
 
 
 
 
 
247
 
248
  demo.launch()
 
1
  from transformers import pipeline
2
  import gradio as gr
3
  import random
4
+ import string
5
+ import paddlehub as hub
6
  import torch
7
  from transformers import AutoModelForCausalLM, AutoTokenizer
8
  from loguru import logger
 
24
  except Exception as e:
25
  return text
26
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  space_ids = {
28
  "spaces/stabilityai/stable-diffusion": "SD 2.1",
29
  "spaces/runwayml/stable-diffusion-v1-5": "SD 1.5",
 
34
  tab_actions = []
35
  tab_titles = []
36
 
37
+ extend_prompt_1 = True
38
+ extend_prompt_2 = True
39
+ extend_prompt_3 = True
40
+
41
  thanks_info = "Thanks: "
42
+ if extend_prompt_1:
43
+ extend_prompt_pipe = pipeline('text-generation', model='yizhangliu/prompt-extend', max_length=77, pad_token_id=0)
44
+ thanks_info += "[<a style='display:inline-block' href='https://huggingface.co/spaces/daspartho/prompt-extend' _blank><font style='color:blue;weight:bold;'>prompt-extend(1)</font></a>]"
45
+ if extend_prompt_2:
46
+ def load_prompter():
47
+ prompter_model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist")
48
+ tokenizer = AutoTokenizer.from_pretrained("gpt2")
49
+ tokenizer.pad_token = tokenizer.eos_token
50
+ tokenizer.padding_side = "left"
51
+ return prompter_model, tokenizer
52
+ prompter_model, prompter_tokenizer = load_prompter()
53
+ def extend_prompt_microsoft(in_text):
54
+ input_ids = prompter_tokenizer(in_text.strip()+" Rephrase:", return_tensors="pt").input_ids
55
+ eos_id = prompter_tokenizer.eos_token_id
56
+ outputs = prompter_model.generate(input_ids, do_sample=False, max_new_tokens=75, num_beams=8, num_return_sequences=8, eos_token_id=eos_id, pad_token_id=eos_id, length_penalty=-1.0)
57
+ output_texts = prompter_tokenizer.batch_decode(outputs, skip_special_tokens=True)
58
+ res = output_texts[0].replace(in_text+" Rephrase:", "").strip()
59
+ return res
60
+ thanks_info += "[<a style='display:inline-block' href='https://huggingface.co/spaces/microsoft/Promptist' _blank><font style='color:blue;weight:bold;'>Promptist(2)</font></a>]"
61
+ if extend_prompt_3:
62
+ MagicPrompt = gr.Interface.load("spaces/Gustavosta/MagicPrompt-Stable-Diffusion")
63
+ thanks_info += "[<a style='display:inline-block' href='https://huggingface.co/spaces/Gustavosta/MagicPrompt-Stable-Diffusion' _blank><font style='color:blue;weight:bold;'>MagicPrompt(3)</font></a>]"
64
+
65
+ do_dreamlike_photoreal = True
66
+ if do_dreamlike_photoreal:
67
+ def add_random_noise(prompt, noise_level=0.1):
68
+ # Get the percentage of characters to add as noise
69
+ percentage_noise = noise_level * 5
70
+ # Get the number of characters to add as noise
71
+ num_noise_chars = int(len(prompt) * (percentage_noise/100))
72
+ # Get the indices of the characters to add noise to
73
+ noise_indices = random.sample(range(len(prompt)), num_noise_chars)
74
+ # Add noise to the selected characters
75
+ prompt_list = list(prompt)
76
+ for index in noise_indices:
77
+ prompt_list[index] = random.choice(string.ascii_letters + string.punctuation)
78
+ new_prompt = "".join(prompt_list)
79
+ return new_prompt
80
+
81
+ dreamlike_photoreal_2_0 = gr.Interface.load("models/dreamlike-art/dreamlike-photoreal-2.0")
82
+ dreamlike_image = gr.Image(label="Dreamlike Photoreal 2.0")
83
+
84
+ tab_actions.append(dreamlike_image)
85
+ tab_titles.append("Dreamlike_2.0")
86
+ thanks_info += "[<a style='display:inline-block' href='https://huggingface.co/models/dreamlike-art/dreamlike-photoreal-2.0' _blank><font style='color:blue;weight:bold;'>dreamlike-photoreal-2.0</font></a>]"
87
 
88
  for space_id in space_ids.keys():
89
  print(space_id, space_ids[space_id])
 
125
  valueSetter.call(element, value);
126
  }
127
  }
128
+ window['tab_advanced'] = 0;
129
+
130
  var gradioEl = document.querySelector('body > gradio-app').shadowRoot;
131
  if (!gradioEl) {
132
  gradioEl = document.querySelector('body > gradio-app');
 
136
  window['gradioEl'] = gradioEl;
137
  tabitems = window['gradioEl'].querySelectorAll('.tabitem');
138
  tabitems_title = window['gradioEl'].querySelectorAll('#tab_demo')[0].children[0].children[0].children;
 
139
  for (var i = 0; i < tabitems.length; i++) {
140
  if (tabitems_title[i].innerText.indexOf('SD') >= 0) {
141
  tabitems[i].childNodes[0].children[0].style.display='none';
 
145
  }
146
  }
147
  } else if (tabitems_title[i].innerText.indexOf('Taiyi') >= 0) {
148
+ tabitems[i].children[0].children[0].children[1].style.display='none';
149
  tabitems[i].children[0].children[0].children[0].children[0].children[1].style.display='none';
150
+ } else if (tabitems_title[i].innerText.indexOf('Dreamlike') >= 0) {
151
+ tabitems[i].childNodes[0].children[0].children[1].style.display='none';
152
+ }
153
  }
154
 
155
  tab_demo = window['gradioEl'].querySelectorAll('#tab_demo')[0];
 
158
  const page1 = window['gradioEl'].querySelectorAll('#page_1')[0];
159
  const page2 = window['gradioEl'].querySelectorAll('#page_2')[0];
160
 
161
+ btns_1 = window['gradioEl'].querySelector('#input_col1_row3').children;
162
  btns_1_split = 100 / btns_1.length;
163
  for (var i = 0; i < btns_1.length; i++) {
164
  btns_1[i].setAttribute('style', 'min-width:0px;width:' + btns_1_split + '%;');
165
  }
166
  page1.style.display = "none";
167
+ page2.style.display = "block";
168
+ prompt_work = window['gradioEl'].querySelectorAll('#prompt_work');
169
+ for (var i = 0; i < prompt_work.length; i++) {
170
+ prompt_work[i].style.display='none';
171
+ }
172
+
173
  window['prevPrompt'] = '';
174
  window['doCheckPrompt'] = 0;
175
  window['checkPrompt'] = function checkPrompt() {
176
  try {
177
+ prompt_work = window['gradioEl'].querySelectorAll('#prompt_work');
178
+ if (prompt_work.length > 0 && prompt_work[0].children.length > 1) {
179
+ prompt_work[0].children[1].style.display='none';
180
+ prompt_work[0].style.display='block';
181
+ }
182
  text_value = window['gradioEl'].querySelectorAll('#prompt_work')[0].querySelectorAll('textarea')[0].value;
183
  progress_bar = window['gradioEl'].querySelectorAll('.progress-bar');
184
  if (window['doCheckPrompt'] === 0 && window['prevPrompt'] !== text_value && progress_bar.length == 0) {
 
220
  return false;
221
  }"""
222
 
223
+ switch_tab_advanced = """async() => {
224
+ window['tab_advanced'] = 1 - window['tab_advanced'];
225
+ if (window['tab_advanced']==0) {
226
+ action = 'none';
227
+ } else {
228
+ action = 'block';
229
+ }
230
+ tabitems = window['gradioEl'].querySelectorAll('.tabitem');
231
+ tabitems_title = window['gradioEl'].querySelectorAll('#tab_demo')[0].children[0].children[0].children;
232
+ for (var i = 0; i < tabitems.length; i++) {
233
+ if (tabitems_title[i].innerText.indexOf('SD') >= 0) {
234
+ for (var j = 0; j < tabitems[i].childNodes[0].children[1].children.length; j++) {
235
+ if (j != 1) {
236
+ tabitems[i].childNodes[0].children[1].children[j].style.display=action;
237
+ }
238
+ }
239
+ } else if (tabitems_title[i].innerText.indexOf('Taiyi') >= 0) {
240
+ tabitems[i].children[0].children[0].children[1].style.display=action;
241
+ }
242
+ }
243
+ return false;
244
+ }"""
245
+
246
  def prompt_extend(prompt, PM):
247
  prompt_en = getTextTrans(prompt, source='zh', target='en')
248
  if PM == 1:
249
  extend_prompt_en = extend_prompt_pipe(prompt_en+',', num_return_sequences=1)[0]["generated_text"]
250
+ elif PM == 2:
251
  extend_prompt_en = extend_prompt_microsoft(prompt_en)
252
+ elif PM == 3:
253
+ extend_prompt_en = MagicPrompt(prompt_en)
254
 
255
  if (prompt != prompt_en):
256
+ logger.info(f"extend_prompt__1_PM=[{PM}]_")
257
  extend_prompt_out = getTextTrans(extend_prompt_en, source='en', target='zh')
258
  else:
259
+ logger.info(f"extend_prompt__2_PM=[{PM}]_")
260
  extend_prompt_out = extend_prompt_en
261
 
262
  return extend_prompt_out
 
269
  extend_prompt_out = prompt_extend(prompt, 2)
270
  return extend_prompt_out
271
 
272
+ def prompt_extend_3(prompt):
273
+ extend_prompt_out = prompt_extend(prompt, 3)
274
+ return extend_prompt_out
275
+
276
+ def prompt_draw_1(prompt, noise_level):
277
+ prompt_en = getTextTrans(prompt, source='zh', target='en')
278
+ if (prompt != prompt_en):
279
+ logger.info(f"draw_prompt______1__")
280
+ prompt_zh = prompt
281
+ else:
282
+ logger.info(f"draw_prompt______2__")
283
+ prompt_zh = getTextTrans(prompt, source='en', target='zh')
284
+
285
+ prompt_with_noise = add_random_noise(prompt_en, noise_level)
286
+ dreamlike_output = dreamlike_photoreal_2_0(prompt_with_noise)
287
+ return prompt_en, prompt_zh, dreamlike_output
288
+
289
+ def prompt_draw_2(prompt):
290
  prompt_en = getTextTrans(prompt, source='zh', target='en')
291
  if (prompt != prompt_en):
292
  logger.info(f"draw_prompt______1__")
 
294
  else:
295
  logger.info(f"draw_prompt______2__")
296
  prompt_zh = getTextTrans(prompt, source='en', target='zh')
 
297
  return prompt_en, prompt_zh
298
+
299
  with gr.Blocks(title='Text-to-Image') as demo:
300
  with gr.Group(elem_id="page_1", visible=True) as page_1:
301
  with gr.Box():
 
309
  with gr.Row(elem_id="input_col1_row1"):
310
  prompt_input0 = gr.Textbox(lines=2, label="Original prompt", visible=True)
311
  with gr.Row(elem_id="input_col1_row2"):
312
+ prompt_work = gr.Textbox(lines=1, label="prompt_work", elem_id="prompt_work", visible=True)
313
+ with gr.Row(elem_id="input_col1_row3"):
314
  with gr.Column(elem_id="input_col1_row2_col0"):
315
  draw_btn_0 = gr.Button(value = "Generate(original)", elem_id="draw-btn-0")
316
+ if extend_prompt_1:
317
+ with gr.Column(elem_id="input_col1_row2_col1"):
318
+ extend_btn_1 = gr.Button(value = "Extend_1",elem_id="extend-btn-1")
319
+ if extend_prompt_2:
320
+ with gr.Column(elem_id="input_col1_row2_col2"):
321
+ extend_btn_2 = gr.Button(value = "Extend_2",elem_id="extend-btn-2")
322
+ if extend_prompt_3:
323
+ with gr.Column(elem_id="input_col1_row2_col3"):
324
+ extend_btn_3 = gr.Button(value = "Extend_3",elem_id="extend-btn-3")
325
  with gr.Column(id="input_col2"):
326
  prompt_input1 = gr.Textbox(lines=2, label="Extend prompt", visible=True)
327
  draw_btn_1 = gr.Button(value = "Generate(extend)", elem_id="draw-btn-1")
328
+ with gr.Row(elem_id="prompt_row1"):
329
+ with gr.Column(id="input_col3"):
330
+ with gr.Row(elem_id="input_col3_row2"):
331
+ prompt_work_zh = gr.Textbox(lines=1, label="prompt_work_zh", elem_id="prompt_work_zh", visible=False)
332
  with gr.Row(elem_id='tab_demo', visible=True).style(height=200):
333
  tab_demo = gr.TabbedInterface(tab_actions, tab_titles)
334
+ if do_dreamlike_photoreal:
335
+ with gr.Row():
336
+ noise_level=gr.Slider(minimum=0.1, maximum=3, step=0.1, label="Dreamlike noise Level: [Higher noise level produces more diverse outputs, while lower noise level produces similar outputs.]")
337
+ with gr.Row():
338
+ switch_tab_advanced_btn = gr.Button(value = "Switch_tab_advanced", elem_id="switch_tab_advanced_btn")
339
+ switch_tab_advanced_btn.click(fn=None, inputs=[], outputs=[], _js=switch_tab_advanced)
340
+ with gr.Row():
341
  gr.HTML(f"<p>{thanks_info}</p>")
342
 
343
+ if extend_prompt_1:
344
+ extend_btn_1.click(fn=prompt_extend_1, inputs=[prompt_input0], outputs=[prompt_input1])
345
+ if extend_prompt_2:
346
+ extend_btn_2.click(fn=prompt_extend_2, inputs=[prompt_input0], outputs=[prompt_input1])
347
+ if extend_prompt_3:
348
+ extend_btn_3.click(fn=prompt_extend_3, inputs=[prompt_input0], outputs=[prompt_input1])
349
+
350
+ if do_dreamlike_photoreal:
351
+ draw_btn_0.click(fn=prompt_draw_1, inputs=[prompt_input0, noise_level], outputs=[prompt_work, prompt_work_zh, dreamlike_image])
352
+ draw_btn_1.click(fn=prompt_draw_1, inputs=[prompt_input1, noise_level], outputs=[prompt_work, prompt_work_zh, dreamlike_image])
353
+ else:
354
+ draw_btn_0.click(fn=prompt_draw_2, inputs=[prompt_input0], outputs=[prompt_work, prompt_work_zh])
355
+ draw_btn_1.click(fn=prompt_draw_2, inputs=[prompt_input1], outputs=[prompt_work, prompt_work_zh])
356
 
357
  demo.launch()