import gradio as gr
import onnxruntime as ort
import numpy as np
from PIL import Image
import json
from huggingface_hub import hf_hub_download

# Load the ONNX model and metadata once at startup so they are not re-downloaded or re-initialized on every request
MODEL_REPO = "AngelBottomless/camie-tagger-onnxruntime"
MODEL_FILE = "camie_tagger_initial.onnx"   # using the smaller initial model for speed
META_FILE = "metadata.json"

# Download model and metadata from HF Hub (cache_dir="." will cache in the Space)
model_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE, cache_dir=".")
meta_path = hf_hub_download(repo_id=MODEL_REPO, filename=META_FILE, cache_dir=".")
session = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
with open(meta_path, "r", encoding="utf-8") as f:
    metadata = json.load(f)
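# Assumed layout of metadata.json, inferred from the keys accessed below; the actual
# tag names, category names, and threshold values in the published file may differ:
#   {
#     "idx_to_tag":          {"0": "1girl", "1": "solo", ...},
#     "tag_to_category":     {"1girl": "general", ...},
#     "category_thresholds": {"general": 0.35, "character": 0.40, ...}
#   }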

# Preprocessing: resize image to 512x512 and normalize to match training
def preprocess_image(pil_image: Image.Image) -> np.ndarray:
    img = pil_image.convert("RGB").resize((512, 512))
    arr = np.array(img).astype(np.float32) / 255.0  # scale pixel values to [0,1]
    arr = np.transpose(arr, (2, 0, 1))  # HWC -> CHW
    arr = np.expand_dims(arr, 0)        # add batch dimension -> (1,3,512,512)
    return arr
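# Rough usage sketch (the file name is hypothetical): preprocess_image(Image.open("sample.jpg"))
# returns a float32 array of shape (1, 3, 512, 512), i.e. (batch, channels, height, width),
# which is the layout the ONNX session expects.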

# Inference: run the ONNX model and collect tags above threshold
def predict_tags(pil_image: Image.Image) -> str:
    # 1. Preprocess image to numpy
    input_tensor = preprocess_image(pil_image)
    # 2. Run model (both initial and refined logits are output)
    input_name = session.get_inputs()[0].name
    initial_logits, refined_logits = session.run(None, {input_name: input_tensor})
    # 3. Convert logits to probabilities (using sigmoid since multi-label)
    probs = 1 / (1 + np.exp(-refined_logits))  # shape (1, 70527)
    probs = probs[0]  # remove batch dim -> (70527,)
    # 4. Thresholding: get tag names for which probability >= category threshold (or default)
    idx_to_tag = metadata["idx_to_tag"]               # map index -> tag string
    tag_to_category = metadata.get("tag_to_category", {})       # map tag -> category
    category_thresholds = metadata.get("category_thresholds", {})# category-specific thresholds
    default_threshold = 0.325
    predicted_tags = []
    for idx, prob in enumerate(probs):
        tag = idx_to_tag[str(idx)]
        cat = tag_to_category.get(tag, "unknown")
        threshold = category_thresholds.get(cat, default_threshold)
        if prob >= threshold:
            # Include this tag; replace underscores with spaces for readability
            predicted_tags.append(tag.replace("_", " "))
    # 5. Return tags as comma-separated string
    if not predicted_tags:
        return "No tags found."
    # Sort alphabetically so the output is stable and easy to scan.
    predicted_tags.sort()
    return ", ".join(predicted_tags)

# Create a simple Gradio interface
demo = gr.Interface(
    fn=predict_tags,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Textbox(label="Predicted Tags", lines=3),
    title="Camie Tagger (ONNX) – Simple Demo",
    description="Upload an anime/manga illustration to get relevant tags predicted by the Camie Tagger model.",
    # Optionally add example images; the listed files must exist in the Space directory:
    # examples=[["example1.jpg"], ["example2.png"]],
)
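# Optional: if the Space sees concurrent traffic, Gradio's request queue can be enabled
# before launching (not required for this simple demo):
# demo.queue()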

# Launch the app. On HF Spaces the Gradio SDK starts the app from the `demo` object automatically,
# so this call mainly matters when running the script locally.
demo.launch()