DamarJati commited on
Commit
9d7496d
Β·
verified Β·
1 Parent(s): 41de93f

Rename modules/app (2).py to modules/wdtagger.py

Browse files
modules/{app (2).py β†’ wdtagger.py} RENAMED
@@ -8,13 +8,6 @@ import onnxruntime as rt
8
  import pandas as pd
9
  from PIL import Image
10
 
11
- TITLE = "WaifuDiffusion Tagger"
12
- DESCRIPTION = """
13
- Demo for the WaifuDiffusion tagger models
14
-
15
- Example image by [γ»γ—β˜†β˜†β˜†](https://www.pixiv.net/en/users/43565085)
16
- """
17
-
18
  # Dataset v3 series of models:
19
  SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
20
  CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
@@ -218,122 +211,21 @@ class Predictor:
218
  return sorted_general_strings, rating, character_res, general_res
219
 
220
 
221
- def main():
222
- args = parse_args()
223
-
224
- predictor = Predictor()
225
-
226
- dropdown_list = [
227
- SWINV2_MODEL_DSV3_REPO,
228
- CONV_MODEL_DSV3_REPO,
229
- VIT_MODEL_DSV3_REPO,
230
- VIT_LARGE_MODEL_DSV3_REPO,
231
- EVA02_LARGE_MODEL_DSV3_REPO,
232
- MOAT_MODEL_DSV2_REPO,
233
- SWIN_MODEL_DSV2_REPO,
234
- CONV_MODEL_DSV2_REPO,
235
- CONV2_MODEL_DSV2_REPO,
236
- VIT_MODEL_DSV2_REPO,
237
- ]
238
-
239
- with gr.Blocks(title=TITLE) as demo:
240
- with gr.Column():
241
- gr.Markdown(
242
- value=f"<h1 style='text-align: center; margin-bottom: 1rem'>{TITLE}</h1>"
243
- )
244
- gr.Markdown(value=DESCRIPTION)
245
- with gr.Row():
246
- with gr.Column(variant="panel"):
247
- image = gr.Image(type="pil", image_mode="RGBA", label="Input")
248
- model_repo = gr.Dropdown(
249
- dropdown_list,
250
- value=SWINV2_MODEL_DSV3_REPO,
251
- label="Model",
252
- )
253
- with gr.Row():
254
- general_thresh = gr.Slider(
255
- 0,
256
- 1,
257
- step=args.score_slider_step,
258
- value=args.score_general_threshold,
259
- label="General Tags Threshold",
260
- scale=3,
261
- )
262
- general_mcut_enabled = gr.Checkbox(
263
- value=False,
264
- label="Use MCut threshold",
265
- scale=1,
266
- )
267
- with gr.Row():
268
- character_thresh = gr.Slider(
269
- 0,
270
- 1,
271
- step=args.score_slider_step,
272
- value=args.score_character_threshold,
273
- label="Character Tags Threshold",
274
- scale=3,
275
- )
276
- character_mcut_enabled = gr.Checkbox(
277
- value=False,
278
- label="Use MCut threshold",
279
- scale=1,
280
- )
281
- with gr.Row():
282
- clear = gr.ClearButton(
283
- components=[
284
- image,
285
- model_repo,
286
- general_thresh,
287
- general_mcut_enabled,
288
- character_thresh,
289
- character_mcut_enabled,
290
- ],
291
- variant="secondary",
292
- size="lg",
293
- )
294
- submit = gr.Button(value="Submit", variant="primary", size="lg")
295
- with gr.Column(variant="panel"):
296
- sorted_general_strings = gr.Textbox(label="Output (string)")
297
- rating = gr.Label(label="Rating")
298
- character_res = gr.Label(label="Output (characters)")
299
- general_res = gr.Label(label="Output (tags)")
300
- clear.add(
301
- [
302
- sorted_general_strings,
303
- rating,
304
- character_res,
305
- general_res,
306
- ]
307
- )
308
-
309
- submit.click(
310
- predictor.predict,
311
- inputs=[
312
- image,
313
- model_repo,
314
- general_thresh,
315
- general_mcut_enabled,
316
- character_thresh,
317
- character_mcut_enabled,
318
- ],
319
- outputs=[sorted_general_strings, rating, character_res, general_res],
320
- )
321
 
322
- gr.Examples(
323
- [["power.jpg", SWINV2_MODEL_DSV3_REPO, 0.35, False, 0.85, False]],
324
- inputs=[
325
- image,
326
- model_repo,
327
- general_thresh,
328
- general_mcut_enabled,
329
- character_thresh,
330
- character_mcut_enabled,
331
- ],
332
- )
333
-
334
- demo.queue(max_size=10)
335
- demo.launch()
336
 
 
 
 
 
 
 
 
 
 
 
 
 
337
 
338
- if __name__ == "__main__":
339
- main()
 
8
  import pandas as pd
9
  from PIL import Image
10
 
 
 
 
 
 
 
 
11
  # Dataset v3 series of models:
12
  SWINV2_MODEL_DSV3_REPO = "SmilingWolf/wd-swinv2-tagger-v3"
13
  CONV_MODEL_DSV3_REPO = "SmilingWolf/wd-convnext-tagger-v3"
 
211
  return sorted_general_strings, rating, character_res, general_res
212
 
213
 
214
+ args = parse_args()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
 
216
+ predictor = Predictor()
 
 
 
 
 
 
 
 
 
 
 
 
 
217
 
218
+ dropdown_list = [
219
+ SWINV2_MODEL_DSV3_REPO,
220
+ CONV_MODEL_DSV3_REPO,
221
+ VIT_MODEL_DSV3_REPO,
222
+ VIT_LARGE_MODEL_DSV3_REPO,
223
+ EVA02_LARGE_MODEL_DSV3_REPO,
224
+ MOAT_MODEL_DSV2_REPO,
225
+ SWIN_MODEL_DSV2_REPO,
226
+ CONV_MODEL_DSV2_REPO,
227
+ CONV2_MODEL_DSV2_REPO,
228
+ VIT_MODEL_DSV2_REPO,
229
+ ]
230
 
231
+