chaore johnsu6616 committed on
Commit 3a59ab0 · 0 Parent(s)

Duplicate from johnsu6616/SD_Helper_01

Co-authored-by: johnsu <[email protected]>

Files changed (4)
  1. .gitattributes +34 -0
  2. README.md +14 -0
  3. app.py +309 -0
  4. requirements.txt +6 -0
.gitattributes ADDED
@@ -0,0 +1,34 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
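
Each entry above routes a large binary file type through Git LFS instead of storing it in the repository directly. As a rough illustration only (fnmatch is not how Git evaluates .gitattributes and does not implement `**` the way Git does), a minimal Python sketch of the pattern matching, with a hand-copied subset of the patterns:

    from fnmatch import fnmatch

    # A sample of the LFS patterns declared above.
    patterns = ['*.ckpt', '*.safetensors', '*.zip', '*tfevents*']

    def tracked_by_lfs(filename):
        """Return True if the filename matches any of the sampled LFS patterns."""
        return any(fnmatch(filename, p) for p in patterns)

    print(tracked_by_lfs('model.safetensors'))  # True
    print(tracked_by_lfs('app.py'))             # False
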
README.md ADDED
@@ -0,0 +1,14 @@
+ ---
+ title: SD_Helper_01
+ emoji: 📊
+ colorFrom: gray
+ colorTo: indigo
+ sdk: gradio
+ sdk_version: 3.30.0
+ app_file: app.py
+ pinned: false
+ license: openrail
+ duplicated_from: johnsu6616/SD_Helper_01
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
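
The block between the two `---` markers is YAML front matter that Hugging Face Spaces reads for configuration (SDK, SDK version, entry file). A minimal sketch of reading it yourself, assuming PyYAML is available (it is not in requirements.txt, since only the Hub needs to parse this header):

    import yaml

    # Split the README on the front-matter delimiters and parse the YAML header.
    with open('README.md', encoding='utf-8') as f:
        _, front_matter, _ = f.read().split('---', 2)

    config = yaml.safe_load(front_matter)
    print(config['sdk'], config['sdk_version'])  # gradio 3.30.0
    print(config['app_file'])                    # app.py
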
app.py ADDED
@@ -0,0 +1,309 @@
+ import random
+ import re
+
+ import gradio as gr
+ import torch
+
+ from transformers import AutoModelForCausalLM
+ from transformers import AutoModelForSeq2SeqLM
+ from transformers import AutoTokenizer
+
+ from transformers import AutoProcessor
+
+ from transformers import pipeline
+
+ from transformers import set_seed
+
+ global ButtonIndex
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ big_processor = AutoProcessor.from_pretrained("microsoft/git-base-coco")
+ big_model = AutoModelForCausalLM.from_pretrained("microsoft/git-base-coco")
+
+ pipeline_01 = pipeline('text-generation', model='succinctly/text2image-prompt-generator', max_new_tokens=256)
+ pipeline_02 = pipeline('text-generation', model='Gustavosta/MagicPrompt-Stable-Diffusion', max_new_tokens=256)
+ pipeline_03 = pipeline('text-generation', model='johnsu6616/ModelExport', max_new_tokens=256)
+
+ zh2en_model = AutoModelForSeq2SeqLM.from_pretrained('Helsinki-NLP/opus-mt-zh-en').eval()
+ zh2en_tokenizer = AutoTokenizer.from_pretrained('Helsinki-NLP/opus-mt-zh-en')
+
+ en2zh_model = AutoModelForSeq2SeqLM.from_pretrained("Helsinki-NLP/opus-mt-en-zh").eval()
+ en2zh_tokenizer = AutoTokenizer.from_pretrained("Helsinki-NLP/opus-mt-en-zh")
+
+ def translate_zh2en(text):
+     with torch.no_grad():
+         text = re.sub(r"[:\-–.!;?_#]", '', text)
+
+         text = re.sub(r'([^\u4e00-\u9fa5])([\u4e00-\u9fa5])', r'\1\n\2', text)
+         text = re.sub(r'([\u4e00-\u9fa5])([^\u4e00-\u9fa5])', r'\1\n\2', text)
+
+         text = text.replace('\n', ',')
+
+         text = re.sub(r'(?<![a-zA-Z])\s+|\s+(?![a-zA-Z])', '', text)
+
+         text = re.sub(r',+', ',', text)
+
+         encoded = zh2en_tokenizer([text], return_tensors='pt')
+         sequences = zh2en_model.generate(**encoded)
+         result = zh2en_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
+
+         result = result.strip()
+
+         if result == "No,no,":
+             result = text
+
+         result = re.sub(r'<.*?>', '', result)
+
+         result = re.sub(r'\b(\w+)\b(?:\W+\1\b)+', r'\1', result, flags=re.IGNORECASE)
+         return result
+
+
+ def translate_en2zh(text):
+     with torch.no_grad():
+
+         encoded = en2zh_tokenizer([text], return_tensors="pt")
+         sequences = en2zh_model.generate(**encoded)
+         result = en2zh_tokenizer.batch_decode(sequences, skip_special_tokens=True)[0]
+
+         result = re.sub(r'\b(\w+)\b(?:\W+\1\b)+', r'\1', result, flags=re.IGNORECASE)
+         return result
+
+ def load_prompter():
+     prompter_model = AutoModelForCausalLM.from_pretrained("microsoft/Promptist")
+     tokenizer = AutoTokenizer.from_pretrained("gpt2")
+     tokenizer.pad_token = tokenizer.eos_token
+     tokenizer.padding_side = "left"
+     return prompter_model, tokenizer
+
+ prompter_model, prompter_tokenizer = load_prompter()
+
+
+ def generate_prompter_pipeline_01(text):
+     seed = random.randint(100, 1000000)
+     set_seed(seed)
+     text_in_english = translate_zh2en(text)
+     response = pipeline_01(text_in_english, num_return_sequences=3)
+     response_list = []
+     for x in response:
+         resp = x['generated_text'].strip()
+
+         if resp != text_in_english and len(resp) > (len(text_in_english) + 4):
+
+             response_list.append(translate_en2zh(resp) + "\n")
+             response_list.append(resp + "\n")
+             response_list.append("\n")
+
+     result = "".join(response_list)
+     result = re.sub(r'[^ ]+\.[^ ]+', '', result)
+     result = result.replace("<", "").replace(">", "")
+
+     if result != "":
+         return result
+
+
+ def generate_prompter_tokenizer_01(text):
+
+     text_in_english = translate_zh2en(text)
+
+     input_ids = prompter_tokenizer(text_in_english.strip() + " Rephrase:", return_tensors="pt").input_ids
+
+     outputs = prompter_model.generate(
+         input_ids,
+         do_sample=False,
+
+         num_beams=3,
+         num_return_sequences=3,
+         pad_token_id=50256,
+         eos_token_id=50256,
+         length_penalty=-1.0
+     )
+     output_texts = prompter_tokenizer.batch_decode(outputs, skip_special_tokens=True)
+
+     result = []
+     for output_text in output_texts:
+
+         output_text = output_text.replace('<', '').replace('>', '')
+         output_text = output_text.split("Rephrase:", 1)[-1].strip()
+
+         result.append(translate_en2zh(output_text) + "\n")
+         result.append(output_text + "\n")
+         result.append("\n")
+     return "".join(result)
+
+ def generate_prompter_pipeline_02(text):
+     seed = random.randint(100, 1000000)
+     set_seed(seed)
+     text_in_english = translate_zh2en(text)
+     response = pipeline_02(text_in_english, num_return_sequences=3)
+     response_list = []
+     for x in response:
+         resp = x['generated_text'].strip()
+         if resp != text_in_english and len(resp) > (len(text_in_english) + 4):
+
+             response_list.append(translate_en2zh(resp) + "\n")
+             response_list.append(resp + "\n")
+             response_list.append("\n")
+
+     result = "".join(response_list)
+     result = re.sub(r'[^ ]+\.[^ ]+', '', result)
+     result = result.replace("<", "").replace(">", "")
+
+     if result != "":
+         return result
+
+ def generate_prompter_pipeline_03(text):
+     seed = random.randint(100, 1000000)
+     set_seed(seed)
+     text_in_english = translate_zh2en(text)
+     response = pipeline_03(text_in_english, num_return_sequences=3)
+     response_list = []
+     for x in response:
+         resp = x['generated_text'].strip()
+         if resp != text_in_english and len(resp) > (len(text_in_english) + 4):
+
+             response_list.append(translate_en2zh(resp) + "\n")
+             response_list.append(resp + "\n")
+             response_list.append("\n")
+
+     result = "".join(response_list)
+     result = re.sub(r'[^ ]+\.[^ ]+', '', result)
+     result = result.replace("<", "").replace(">", "")
+
+     if result != "":
+         return result
+
+ def generate_render(text, choice):
+     if choice == '★pipeline mode (succinctly)':
+         outputs = generate_prompter_pipeline_01(text)
+         return outputs, choice
+     elif choice == '★★tokenizer mode':
+         outputs = generate_prompter_tokenizer_01(text)
+         return outputs, choice
+     elif choice == '★★★pipeline model (Gustavosta)':
+         outputs = generate_prompter_pipeline_02(text)
+         return outputs, choice
+     elif choice == 'pipeline model (John)_self-trained test, unstable data':
+         outputs = generate_prompter_pipeline_03(text)
+         return outputs, choice
+
+ def get_prompt_from_image(input_image, choice):
+     image = input_image.convert('RGB')
+     pixel_values = big_processor(images=image, return_tensors="pt").to(device).pixel_values
+     generated_ids = big_model.to(device).generate(pixel_values=pixel_values)
+     generated_caption = big_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
+     text = re.sub(r"[:\-–.!;?_#]", '', generated_caption)
+
+     if choice == '★pipeline mode (succinctly)':
+         outputs = generate_prompter_pipeline_01(text)
+         return outputs
+     elif choice == '★★tokenizer mode':
+         outputs = generate_prompter_tokenizer_01(text)
+         return outputs
+     elif choice == '★★★pipeline model (Gustavosta)':
+         outputs = generate_prompter_pipeline_02(text)
+         return outputs
+     elif choice == 'pipeline model (John)_self-trained test, unstable data':
+         outputs = generate_prompter_pipeline_03(text)
+         return outputs
+
+ with gr.Blocks() as block:
+     with gr.Column():
+         with gr.Tab('Workspace'):
+             with gr.Row():
+                 input_text = gr.Textbox(lines=12, label='Input text', placeholder='Enter text here...')
+                 input_image = gr.Image(type='pil', label="Select an image (recognition is poor)")
+             with gr.Row():
+                 txt_prompter_btn = gr.Button('Text-to-text')
+                 pic_prompter_btn = gr.Button('Image-to-text')
+             with gr.Row():
+                 radio_btn = gr.Radio(
+                     label="Choose an output mode",
+                     choices=['★pipeline mode (succinctly)', '★★tokenizer mode', '★★★pipeline model (Gustavosta)',
+                              'pipeline model (John)_self-trained test, unstable data'],
+
+                     value='★pipeline mode (succinctly)'
+                 )
+
+             with gr.Row():
+                 Textbox_1 = gr.Textbox(lines=6, label='Generated prompts')
+             with gr.Row():
+                 Textbox_2 = gr.Textbox(lines=6, label='Test info')
+
+         with gr.Tab('Test area'):
+             with gr.Row():
+                 input_test01 = gr.Textbox(lines=2, label='Chinese-to-English translation', placeholder='Enter text here...')
+                 test01_btn = gr.Button('Run')
+                 Textbox_test01 = gr.Textbox(lines=2, label='Output')
+             with gr.Row():
+                 input_test02 = gr.Textbox(lines=2, label='English-to-Chinese translation (imprecise)', placeholder='Enter text here...')
+                 test02_btn = gr.Button('Run')
+                 Textbox_test02 = gr.Textbox(lines=2, label='Output')
+             with gr.Row():
+                 input_test03 = gr.Textbox(lines=2, label='★pipeline mode (succinctly)', placeholder='Enter text here...')
+                 test03_btn = gr.Button('Run')
+                 Textbox_test03 = gr.Textbox(lines=2, label='Output')
+             with gr.Row():
+                 input_test04 = gr.Textbox(lines=2, label='★★tokenizer mode', placeholder='Enter text here...')
+                 test04_btn = gr.Button('Run')
+                 Textbox_test04 = gr.Textbox(lines=2, label='Output')
+             with gr.Row():
+                 input_test05 = gr.Textbox(lines=2, label='★★★pipeline model (Gustavosta)', placeholder='Enter text here...')
+                 test05_btn = gr.Button('Run')
+                 Textbox_test05 = gr.Textbox(lines=2, label='Output')
+             with gr.Row():
+                 input_test06 = gr.Textbox(lines=2, label='pipeline model (John)_self-trained test, unstable data', placeholder='Enter text here...')
+                 test06_btn = gr.Button('Run')
+                 Textbox_test06 = gr.Textbox(lines=2, label='Output')
+
+     txt_prompter_btn.click(
+         fn=generate_render,
+         inputs=[input_text, radio_btn],
+         outputs=[Textbox_1, Textbox_2]
+     )
+
+     pic_prompter_btn.click(
+         fn=get_prompt_from_image,
+         inputs=[input_image, radio_btn],
+         outputs=Textbox_1
+     )
+
+     test01_btn.click(
+         fn=translate_zh2en,
+         inputs=input_test01,
+         outputs=Textbox_test01
+     )
+
+     test02_btn.click(
+         fn=translate_en2zh,
+         inputs=input_test02,
+         outputs=Textbox_test02
+     )
+
+     test03_btn.click(
+         fn=generate_prompter_pipeline_01,
+         inputs=input_test03,
+         outputs=Textbox_test03
+     )
+
+     test04_btn.click(
+         fn=generate_prompter_tokenizer_01,
+         inputs=input_test04,
+         outputs=Textbox_test04
+     )
+
+     test05_btn.click(
+         fn=generate_prompter_pipeline_02,
+         inputs=input_test05,
+         outputs=Textbox_test05
+     )
+
+
+     test06_btn.click(
+         fn=generate_prompter_pipeline_03,
+         inputs=input_test06,
+         outputs=Textbox_test06
+     )
+
+ block.queue(max_size=64).launch(show_api=False, debug=True, share=False, server_name='0.0.0.0')
+
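
The preprocessing in translate_zh2en is the least obvious part of app.py: it inserts a boundary at every transition between CJK and non-CJK characters, turns those boundaries into commas, and strips whitespace everywhere except between Latin words, so the opus-mt model sees a clean comma-separated prompt. A standalone sketch of just that normalization (no model calls), with a hypothetical sample input:

    import re

    def normalize_mixed_text(text):
        """Reproduce the prompt normalization from translate_zh2en."""
        text = re.sub(r"[:\-–.!;?_#]", '', text)                               # drop punctuation the MT model mangles
        text = re.sub(r'([^\u4e00-\u9fa5])([\u4e00-\u9fa5])', r'\1\n\2', text)  # break before a CJK run
        text = re.sub(r'([\u4e00-\u9fa5])([^\u4e00-\u9fa5])', r'\1\n\2', text)  # break after a CJK run
        text = text.replace('\n', ',')                                          # boundaries become commas
        text = re.sub(r'(?<![a-zA-Z])\s+|\s+(?![a-zA-Z])', '', text)            # keep spaces only between Latin words
        return re.sub(r',+', ',', text)                                         # collapse repeated commas

    print(normalize_mixed_text('一個女孩 masterpiece, best quality'))
    # -> '一個女孩,masterpiece,best quality'
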
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ transformers==4.29.2
+ torch==2.0.0
+ pytorch_lightning==2.0.2
+ gradio==3.30.0
+ sentencepiece==0.1.99
+ sacremoses==0.0.53
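
The gradio pin matches the sdk_version declared in README.md, and the other pins match what app.py imports. A small sketch for checking the installed versions against these pins locally, using only the standard library:

    from importlib.metadata import version

    # The exact pins from requirements.txt; a mismatch usually means a stale environment.
    pins = {
        'transformers': '4.29.2',
        'torch': '2.0.0',
        'gradio': '3.30.0',
    }

    for package, expected in pins.items():
        installed = version(package)
        status = 'ok' if installed == expected else f'expected {expected}'
        print(f'{package}=={installed} ({status})')
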