File size: 5,325 Bytes
086766e
 
 
 
 
 
 
7ec5b17
086766e
7ec5b17
086766e
 
 
 
 
7ec5b17
086766e
 
7ec5b17
 
 
086766e
 
7ec5b17
 
 
086766e
 
 
7ec5b17
 
 
 
 
086766e
7ec5b17
 
 
086766e
 
 
7ec5b17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
086766e
7ec5b17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
086766e
7ec5b17
086766e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import gradio as gr
import onnxruntime as ort
import numpy as np
from PIL import Image
import json
from huggingface_hub import hf_hub_download

# Download the ONNX model and its tag metadata from the Hugging Face Hub at
# startup (cached under the current directory) and build the inference session
# once, so every request reuses it.
MODEL_REPO = "AngelBottomless/camie-tagger-onnxruntime"
MODEL_FILE = "camie_tagger_initial.onnx"
META_FILE = "metadata.json"
model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE, cache_dir=".")
meta_path = hf_hub_download(repo_id=MODEL_REPO, filename=META_FILE, cache_dir=".")
# Restrict ONNX Runtime to the CPU execution provider.
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
# Read the metadata inside a context manager so the file handle is closed
# deterministically instead of being leaked until garbage collection.
with open(meta_path, "r", encoding="utf-8") as meta_file:
    metadata = json.load(meta_file)
# Image -> model input tensor conversion.
def preprocess_image(pil_image: Image.Image) -> np.ndarray:
    """Turn *pil_image* into a batched, channel-first float32 array.

    Steps: force RGB, resize to 512x512, scale pixel values from [0, 255]
    to [0, 1], reorder HWC -> CHW, then prepend a batch axis, giving an
    array of shape (1, 3, 512, 512).
    """
    resized = pil_image.convert("RGB").resize((512, 512))
    scaled = np.array(resized).astype(np.float32) / 255.0
    channels_first = np.transpose(scaled, (2, 0, 1))
    return np.expand_dims(channels_first, 0)

# Probability cutoff applied to any category that has no entry in
# metadata["category_thresholds"].
DEFAULT_THRESHOLD = 0.325


def _predict_probs(pil_image: Image.Image) -> np.ndarray:
    """Run the ONNX model on *pil_image* and return per-tag probabilities.

    The session returns two outputs; only the second (refined logits) is
    used. A sigmoid is applied and the first batch row is returned --
    presumably a 1-D array with one entry per tag index (TODO confirm
    against the exported model).
    """
    input_tensor = preprocess_image(pil_image)
    input_name = session.get_inputs()[0].name
    _initial_logits, refined_logits = session.run(None, {input_name: input_tensor})
    return (1 / (1 + np.exp(-refined_logits)))[0]


def _collect_tags(probs):
    """Select tags whose probability meets their category's threshold.

    Returns ``(results_by_cat, hits)``: ``results_by_cat`` maps each
    category to its ``(tag, prob)`` pairs, and ``hits`` lists every selected
    ``(tag, prob)`` pair in model index order.
    """
    # JSON object keys are always strings, hence the str(idx) lookup below.
    idx_to_tag = metadata["idx_to_tag"]
    tag_to_category = metadata.get("tag_to_category", {})
    category_thresholds = metadata.get("category_thresholds", {})
    results_by_cat = {}
    hits = []
    for idx, raw_prob in enumerate(probs):
        prob = float(raw_prob)
        tag = idx_to_tag[str(idx)]
        cat = tag_to_category.get(tag, "unknown")
        if prob >= category_thresholds.get(cat, DEFAULT_THRESHOLD):
            results_by_cat.setdefault(cat, []).append((tag, prob))
            hits.append((tag, prob))
    return results_by_cat, hits


def _format_prompt(hits) -> str:
    """Join tag names (underscores -> spaces), highest probability first."""
    if not hits:
        return "No tags predicted."
    # sorted() is stable even with reverse=True, so tags with equal
    # probability keep model index order.
    ranked = sorted(hits, key=lambda pair: pair[1], reverse=True)
    return ", ".join(tag.replace("_", " ") for tag, _ in ranked)


def _format_detailed(results_by_cat) -> str:
    """Render tags grouped by category as Markdown, each group sorted by probability."""
    if not results_by_cat:
        return "No tags predicted for this image."
    # Two trailing spaces + newline is a Markdown hard line break.
    lines = ["**Predicted Tags by Category:**  \n"]
    for cat, tag_list in results_by_cat.items():
        ranked = sorted(tag_list, key=lambda pair: pair[1], reverse=True)
        lines.append(f"**Category: {cat}** – {len(ranked)} tags")
        for tag, prob in ranked:
            tag_pretty = tag.replace("_", " ")
            lines.append(f"- {tag_pretty} (Prob: {prob:.3f})")
        lines.append("")  # blank line between categories
    return "\n".join(lines)


def tag_image(pil_image: Image.Image, output_format: str = "Prompt-style Tags") -> str:
    """Predict tags for *pil_image* and format them for display.

    Args:
        pil_image: Input image; may be ``None`` (e.g. the button is clicked
            before an image is uploaded), in which case a hint is returned.
        output_format: ``"Prompt-style Tags"`` yields a comma-separated
            prompt ordered by probability; any other value yields a Markdown
            breakdown grouped by category.

    Returns:
        The formatted tags, or a short message when there is no image or no
        tag clears its threshold.
    """
    if pil_image is None:
        return "Please upload an image first."
    results_by_cat, hits = _collect_tags(_predict_probs(pil_image))
    if output_format == "Prompt-style Tags":
        return _format_prompt(hits)
    return _format_detailed(results_by_cat)

# Build the Gradio Blocks UI
demo = gr.Blocks(theme=gr.themes.Soft())  # built-in theme for nicer styling

with demo:
    # Header section
    gr.Markdown("# 🏷️ Camie Tagger – Anime Image Tagging\nThis demo uses an ONNX model of Camie Tagger to label anime illustrations with tags. Upload an image and click **Tag Image** to see predictions.")
    gr.Markdown("*(Note: The model will predict a large number of tags across categories like character, general, artist, etc. You can choose a concise prompt-style output or a detailed category-wise breakdown.)*")
    # Input/output section
    with gr.Row():
        # Left column: image input, output-format selector, and trigger button
        with gr.Column():
            image_in = gr.Image(type="pil", label="Input Image")
            format_choice = gr.Radio(choices=["Prompt-style Tags", "Detailed Output"], value="Prompt-style Tags", label="Output Format")
            tag_button = gr.Button("🔍 Tag Image")
        # Right column: result rendered as Markdown (bold headings, bullet lists)
        with gr.Column():
            output_box = gr.Markdown("")
    # Example images. tag_image takes (image, output_format), so `inputs` must
    # list both components and every example row must supply both values;
    # with only the image, caching the examples would call tag_image with a
    # missing argument. The file paths must exist in the Space.
    gr.Examples(
        examples=[
            ["example1.jpg", "Prompt-style Tags"],
            ["example2.png", "Prompt-style Tags"],
        ],
        inputs=[image_in, format_choice],
        outputs=output_box,
        fn=tag_image,
        cache_examples=True,
    )
    # Run tagging when the button is clicked
    tag_button.click(fn=tag_image, inputs=[image_in, format_choice], outputs=output_box)
    # Footer / info
    gr.Markdown("----\n**Model:** [Camie Tagger ONNX](https://huggingface.co/AngelBottomless/camie-tagger-onnxruntime)   •   **Base Model:** Camais03/camie-tagger (61% F1 on 70k tags)   •   **ONNX Runtime:** for efficient CPU inference   •   *Demo built with Gradio Blocks.*")

# Launch the app (Spaces runs this file directly)
demo.launch()