Upload 9 files
Browse files
- export.py +30 -0
- infer-refined.py +89 -35
- infer.py +139 -97
- model_code.py +956 -0
- model_config.json +9 -0
- model_info_initial_only.json +9 -0
- model_no_flash.py +195 -0
- thresholds.json +170 -0
export.py
ADDED
@@ -0,0 +1,30 @@
import torch
import torchvision.models as models
from model_code import InitialOnlyImageTagger  # Assume model_code.py classes are accessible
from safetensors.torch import load_file

# Load the trained weights (Initial-only model). Adjust path to your weights file.
#weights_path = "model_initial_only.pt"
safetensors_path = 'model_initial.safetensors'
state_dict = load_file(safetensors_path, device='cpu')
#state_dict = torch.load(weights_path, map_location="cpu")
# Instantiate the model with the same parameters as training
model = InitialOnlyImageTagger(total_tags=70527, dataset=None, pretrained=True)  # dataset not needed for forward
model.load_state_dict(state_dict)
model.eval()  # set to evaluation mode

# Define example input – a dummy image tensor of the expected input shape (1, 3, 512, 512)
dummy_input = torch.randn(1, 3, 512, 512, dtype=torch.float32)

# Export to ONNX
onnx_path = "camie_tagger_initial_v15.onnx"
torch.onnx.export(
    model, dummy_input, onnx_path,
    export_params=True,        # store the trained parameter weights in the model file
    opset_version=13,          # ONNX opset version (13 is widely supported)
    do_constant_folding=True,  # optimize constant expressions
    input_names=["input"],
    output_names=["initial_logits", "refined_logits"],  # model.forward returns two outputs (identical for InitialOnly)
    dynamic_axes={"input": {0: "batch_size"}}  # allow variable batch size
)
print(f"ONNX model saved to: {onnx_path}")
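As a quick sanity check after export, a minimal sketch along these lines (not part of the upload; it assumes onnxruntime is installed and the onnx_path used above) runs the exported graph once on random input and confirms both outputs have shape (1, 70527):

import numpy as np
import onnxruntime as ort

# Load the exported model and run it on a random batch (illustrative only).
session = ort.InferenceSession("camie_tagger_initial_v15.onnx", providers=["CPUExecutionProvider"])
dummy = np.random.rand(1, 3, 512, 512).astype(np.float32)
input_name = session.get_inputs()[0].name
initial_logits, refined_logits = session.run(None, {input_name: dummy})
print(initial_logits.shape, refined_logits.shape)  # expected: (1, 70527) (1, 70527)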
infer-refined.py
CHANGED
@@ -42,73 +42,120 @@ def preprocess_image(img_path, target_size=512, keep_aspect=True):
    arr = np.expand_dims(arr, axis=0)
    return arr

# Example input
def load_thresholds(threshold_json_path, mode="balanced"):
    """
    Loads thresholds from the given JSON file, using a particular mode
    (e.g. 'balanced', 'high_precision', 'high_recall') for each category.

    Returns:
        thresholds_by_category (dict): e.g. { "general": 0.328..., "character": 0.304..., ... }
        fallback_threshold (float): The overall threshold if category not found
    """
    with open(threshold_json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # The fallback threshold from the "overall" section for the chosen mode
    fallback_threshold = data["overall"][mode]["threshold"]

    # Build a dict of thresholds keyed by category
    thresholds_by_category = {}
    if "categories" in data:
        for cat_name, cat_modes in data["categories"].items():
            # If the chosen mode is present for that category, use it;
            # otherwise fall back to the "overall" threshold.
            if mode in cat_modes and "threshold" in cat_modes[mode]:
                thresholds_by_category[cat_name] = cat_modes[mode]["threshold"]
            else:
                thresholds_by_category[cat_name] = fallback_threshold

    return thresholds_by_category, fallback_threshold

def onnx_inference(
    img_paths,
    onnx_path="camie_refined_no_flash.onnx",
    metadata_file="metadata.json",
    threshold_json_path="thresholds.json",
    mode="balanced",
    target_size=512,
    keep_aspect=True
):
    """
    Loads the ONNX model, runs inference on a list of image paths,
    and applies category-wise thresholds from thresholds.json (per the chosen mode).

    Args:
        img_paths           : List of paths to images.
        onnx_path           : Path to the exported ONNX model file.
        metadata_file       : Path to metadata.json that contains idx_to_tag, tag_to_category, etc.
        threshold_json_path : Path to thresholds.json containing category-wise threshold info.
        mode                : "balanced", "high_precision", or "high_recall".
        target_size         : Final size of preprocessed images (512 by default).
        keep_aspect         : If True, preserve aspect ratio when resizing, pad with black.

    Returns:
        A list of dicts, one per input image, each containing:
        {
            "initial_logits": np.ndarray of shape (N_tags,),
            "refined_logits": np.ndarray of shape (N_tags,),
            "predicted_indices": list of tag indices that exceeded threshold,
            "predicted_tags": list of predicted tag strings,
            ...
        }
    """
    # 1) Initialize ONNX runtime session
    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    # For GPU usage, you could do e.g.:
    # session = ort.InferenceSession(onnx_path, providers=["CUDAExecutionProvider"])

    # 2) Pre-load metadata
    with open(metadata_file, "r", encoding="utf-8") as f:
        metadata = json.load(f)
    idx_to_tag = metadata["idx_to_tag"]  # e.g. { "0": "brown_hair", "1": "blue_eyes", ... }
    tag_to_category = metadata.get("tag_to_category", {})

    # Load thresholds from thresholds.json using the specified mode
    thresholds_by_category, fallback_threshold = load_thresholds(threshold_json_path, mode)

    # 3) Preprocess each image into a batch
    batch_tensors = []
    for img_path in img_paths:
        x = preprocess_image(img_path, target_size=target_size, keep_aspect=keep_aspect)
        batch_tensors.append(x)
    # Concatenate along the batch dimension => shape (batch_size, 3, H, W)
    batch_input = np.concatenate(batch_tensors, axis=0)

    # 4) Run inference
    input_name = session.get_inputs()[0].name  # typically "image" or "input"
    outputs = session.run(None, {input_name: batch_input})
    # Typically we get [initial_tags, refined_tags] as output
    initial_preds, refined_preds = outputs  # shapes => (batch_size, N_tags)

    # 5) Convert logits -> probabilities -> apply category-specific thresholds
    batch_results = []
    for i in range(initial_preds.shape[0]):
        init_logit = initial_preds[i, :]  # shape (N_tags,)
        ref_logit = refined_preds[i, :]   # shape (N_tags,)
        ref_prob = 1.0 / (1.0 + np.exp(-ref_logit))  # shape (N_tags,)

        predicted_indices = []
        predicted_tags = []

        # Check each tag against the category threshold
        for idx in range(ref_logit.shape[0]):
            tag_name = idx_to_tag[str(idx)]  # Convert index->string->tag name
            category = tag_to_category.get(tag_name, "general")  # fallback to "general" if missing
            cat_threshold = thresholds_by_category.get(category, fallback_threshold)

            if ref_prob[idx] >= cat_threshold:
                predicted_indices.append(idx)
                predicted_tags.append(tag_name)

        # Build result for this image
        result_dict = {
            "initial_logits": init_logit,
            "refined_logits": ref_logit,
            "predicted_indices": predicted_indices,
            "predicted_tags": predicted_tags,
        }
        batch_results.append(result_dict)

@@ -116,14 +163,21 @@ def onnx_inference(img_paths,

if __name__ == "__main__":
    # Example usage
    images = ["images.png"]
    results = onnx_inference(
        img_paths=images,
        onnx_path="camie_refined_no_flash_v15.onnx",
        metadata_file="metadata.json",
        threshold_json_path="thresholds.json",
        mode="balanced",  # or "high_precision", "high_recall"
        target_size=512,
        keep_aspect=True
    )

    for i, res in enumerate(results):
        print(f"Image: {images[i]}")
        print(f"  # of predicted tags above threshold: {len(res['predicted_indices'])}")
        # Show the predicted tags
        sample_tags = res['predicted_tags']
        print("  Sample predicted tags:", sample_tags)
        print()
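For reference, load_thresholds above only assumes this nesting in thresholds.json: an "overall" block and an optional "categories" block, each keyed by mode with a "threshold" value. A minimal sketch with placeholder numbers (the shipped thresholds.json is authoritative; it assumes load_thresholds from infer-refined.py is in scope):

import json

# Illustrative structure only; the values are made up, not taken from the shipped file.
example = {
    "overall": {"balanced": {"threshold": 0.33}},
    "categories": {
        "general": {"balanced": {"threshold": 0.33}},
        "character": {"balanced": {"threshold": 0.30}}
    }
}
with open("thresholds_example.json", "w", encoding="utf-8") as f:
    json.dump(example, f)

thresholds_by_category, fallback = load_thresholds("thresholds_example.json", mode="balanced")
print(thresholds_by_category, fallback)  # -> {'general': 0.33, 'character': 0.30} 0.33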
infer.py
CHANGED
@@ -1,98 +1,140 @@
import onnxruntime as ort
import numpy as np
import json
from PIL import Image

# 1) Load ONNX model
session = ort.InferenceSession("camie_tagger_initial_v15.onnx", providers=["CPUExecutionProvider"])

# 2) Preprocess your image (512x512, etc.)
def preprocess_image(img_path):
    """
    Loads and resizes an image to 512x512, converts it to float32 [0..1],
    and returns a (1,3,512,512) NumPy array (NCHW format).
    """
    img = Image.open(img_path).convert("RGB").resize((512, 512))
    x = np.array(img).astype(np.float32) / 255.0
    x = np.transpose(x, (2, 0, 1))  # HWC -> CHW
    x = np.expand_dims(x, 0)        # add batch dimension -> (1,3,512,512)
    return x

# Example input
def load_thresholds(threshold_json_path, mode="balanced"):
    """
    Loads thresholds from the given JSON file, using a particular mode
    (e.g. 'balanced', 'high_precision', 'high_recall') for each category.

    Returns:
        thresholds_by_category (dict): e.g. { "general": 0.328..., "character": 0.304..., ... }
        fallback_threshold (float): The overall threshold if category not found
    """
    with open(threshold_json_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # The fallback threshold from the "overall" section for the chosen mode
    fallback_threshold = data["overall"][mode]["threshold"]

    # Build a dict of thresholds keyed by category
    thresholds_by_category = {}
    if "categories" in data:
        for cat_name, cat_modes in data["categories"].items():
            # If the chosen mode is present for that category, use it;
            # otherwise fall back to the "overall" threshold.
            if mode in cat_modes and "threshold" in cat_modes[mode]:
                thresholds_by_category[cat_name] = cat_modes[mode]["threshold"]
            else:
                thresholds_by_category[cat_name] = fallback_threshold

    return thresholds_by_category, fallback_threshold

def inference(
    input_path,
    output_format="verbose",
    mode="balanced",
    threshold_json_path="thresholds.json",
    metadata_path="metadata.json"
):
    """
    Run inference on an image using the loaded ONNX model, then apply
    category-wise thresholds from `thresholds.json` for the chosen mode.

    Arguments:
        input_path (str)          : Path to the image file for inference.
        output_format (str)       : Either "verbose" or "as_prompt".
        mode (str)                : "balanced", "high_precision", or "high_recall"
        threshold_json_path (str) : Path to the JSON file with category thresholds.
        metadata_path (str)       : Path to the metadata JSON file with category info.

    Returns:
        str: The predicted tags in either verbose or comma-separated format.
    """
    # 1) Preprocess
    input_tensor = preprocess_image(input_path)

    # 2) Run inference
    input_name = session.get_inputs()[0].name
    outputs = session.run(None, {input_name: input_tensor})
    initial_logits, refined_logits = outputs  # shape: (1, 70527) each

    # 3) Convert logits to probabilities
    refined_probs = 1 / (1 + np.exp(-refined_logits))  # shape: (1, 70527)

    # 4) Load metadata & retrieve threshold info
    with open(metadata_path, "r", encoding="utf-8") as f:
        metadata = json.load(f)

    idx_to_tag = metadata["idx_to_tag"]  # e.g. { "0": "brown_hair", "1": "blue_eyes", ... }
    tag_to_category = metadata.get("tag_to_category", {})
    # Load thresholds from thresholds.json using the specified mode
    thresholds_by_category, fallback_threshold = load_thresholds(threshold_json_path, mode)

    # 5) Collect predictions by category
    results_by_category = {}
    num_tags = refined_probs.shape[1]

    for i in range(num_tags):
        prob = float(refined_probs[0, i])
        tag_name = idx_to_tag[str(i)]  # str(i) because metadata uses string keys
        category = tag_to_category.get(tag_name, "general")

        # Determine the threshold to use for this category
        cat_threshold = thresholds_by_category.get(category, fallback_threshold)

        if prob >= cat_threshold:
            if category not in results_by_category:
                results_by_category[category] = []
            results_by_category[category].append((tag_name, prob))

    # 6) Depending on output_format, produce different return strings
    if output_format == "as_prompt":
        # Flatten all predicted tags across categories
        all_predicted_tags = []
        for cat, tags_list in results_by_category.items():
            # We only need the tag name in as_prompt format
            for tname, tprob in tags_list:
                # convert underscores to spaces
                tag_name_spaces = tname.replace("_", " ")
                all_predicted_tags.append(tag_name_spaces)

        # Create a comma-separated string
        prompt_string = ", ".join(all_predicted_tags)
        return prompt_string

    else:  # "verbose"
        # We'll build a multiline string describing the predictions
        lines = []
        lines.append("Predicted Tags by Category:\n")
        for cat, tags_list in results_by_category.items():
            lines.append(f"Category: {cat} | Predicted {len(tags_list)} tags")
            # Sort descending by probability
            for tname, tprob in sorted(tags_list, key=lambda x: x[1], reverse=True):
                lines.append(f"  Tag: {tname:30s} Prob: {tprob:.4f}")
            lines.append("")  # blank line after each category
        # Join lines with newlines
        verbose_output = "\n".join(lines)
        return verbose_output

if __name__ == "__main__":
    result = inference("", output_format="as_prompt")
    print(result)
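The __main__ block above passes an empty path; in practice you call inference with a real image file. A hedged usage sketch ("example.jpg" is a placeholder, and metadata.json / thresholds.json are expected next to the script as in the defaults above):

# Hypothetical image path; not part of the upload.
prompt = inference("example.jpg", output_format="as_prompt", mode="balanced")
print(prompt)    # e.g. "brown hair, blue eyes, ..."

report = inference("example.jpg", output_format="verbose", mode="high_precision")
print(report)    # per-category listing, sorted by probability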
model_code.py
ADDED
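model_code.py below bundles the PyTorch-side model definitions, including InitialOnlyImageTagger (which export.py imports) and the lightweight TagDataset wrapper. As an illustrative, hedged sketch of the PyTorch path (file names mirror export.py, the image path is a placeholder, and metadata.json is assumed to hold idx_to_tag / tag_to_category with string keys as in infer.py):

import json
from safetensors.torch import load_file
from model_code import InitialOnlyImageTagger, TagDataset

# Build the TagDataset so category lookups work; int keys are built explicitly
# because metadata.json stores tag indices as strings (assumption).
with open("metadata.json", "r", encoding="utf-8") as f:
    meta = json.load(f)
dataset = TagDataset(
    total_tags=70527,
    idx_to_tag={int(k): v for k, v in meta["idx_to_tag"].items()},
    tag_to_category=meta.get("tag_to_category", {}),
)

model = InitialOnlyImageTagger(total_tags=70527, dataset=dataset, pretrained=True)
model.load_state_dict(load_file("model_initial.safetensors", device="cpu"))
model.eval()

out = model.predict("example.jpg", threshold=0.325)  # "example.jpg" is a placeholder
print(model.get_tags_from_predictions(out["predictions"], include_probabilities=False))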
@@ -0,0 +1,956 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import efficientnet_v2_l, EfficientNet_V2_L_Weights
from PIL import Image
from typing import Optional
import torchvision.transforms as transforms
import os
import json

class InitialOnlyImageTagger(nn.Module):
    """
    A lightweight version of ImageTagger that only includes the backbone and initial classifier.
    This model uses significantly less VRAM than the full model.
    """
    def __init__(self, total_tags, dataset, model_name='efficientnet_v2_l',
                 dropout=0.1, pretrained=True):
        super().__init__()
        # Debug and stats flags
        self._flags = {
            'debug': False,
            'model_stats': False
        }

        # Core model config
        self.dataset = dataset
        self.embedding_dim = 1280  # Fixed to EfficientNetV2-L output dimension

        # Initialize backbone
        if model_name == 'efficientnet_v2_l':
            weights = EfficientNet_V2_L_Weights.DEFAULT if pretrained else None
            self.backbone = efficientnet_v2_l(weights=weights)
            self.backbone.classifier = nn.Identity()

        # Spatial pooling only - no projection
        self.spatial_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Initial tag prediction with bottleneck
        self.initial_classifier = nn.Sequential(
            nn.Linear(self.embedding_dim, self.embedding_dim * 2),
            nn.LayerNorm(self.embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(self.embedding_dim * 2, self.embedding_dim),
            nn.LayerNorm(self.embedding_dim),
            nn.GELU(),
            nn.Linear(self.embedding_dim, total_tags)
        )

        # Temperature scaling
        self.temperature = nn.Parameter(torch.ones(1) * 1.5)

    @property
    def debug(self):
        return self._flags['debug']

    @debug.setter
    def debug(self, value):
        self._flags['debug'] = value

    @property
    def model_stats(self):
        return self._flags['model_stats']

    @model_stats.setter
    def model_stats(self, value):
        self._flags['model_stats'] = value

    def preprocess_image(self, image_path, image_size=512):
        """Process an image for inference using same preprocessing as training"""
        if not os.path.exists(image_path):
            raise ValueError(f"Image not found at path: {image_path}")

        # Initialize the same transform used during training
        transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        try:
            with Image.open(image_path) as img:
                # Convert RGBA or Palette images to RGB
                if img.mode in ('RGBA', 'P'):
                    img = img.convert('RGB')

                # Get original dimensions
                width, height = img.size
                aspect_ratio = width / height

                # Calculate new dimensions to maintain aspect ratio
                if aspect_ratio > 1:
                    new_width = image_size
                    new_height = int(new_width / aspect_ratio)
                else:
                    new_height = image_size
                    new_width = int(new_height * aspect_ratio)

                # Resize with LANCZOS filter
                img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

                # Create new image with padding
                new_image = Image.new('RGB', (image_size, image_size), (0, 0, 0))
                paste_x = (image_size - new_width) // 2
                paste_y = (image_size - new_height) // 2
                new_image.paste(img, (paste_x, paste_y))

                # Apply transforms (without normalization)
                img_tensor = transform(new_image)
                return img_tensor
        except Exception as e:
            raise Exception(f"Error processing {image_path}: {str(e)}")

    def forward(self, x):
        """Forward pass with only the initial predictions"""
        # Image Feature Extraction
        features = self.backbone.features(x)
        features = self.spatial_pool(features).squeeze(-1).squeeze(-1)

        # Initial Tag Predictions
        initial_logits = self.initial_classifier(features)
        initial_preds = torch.clamp(initial_logits / self.temperature, min=-15.0, max=15.0)

        # For API compatibility with the full model, return the same predictions twice
        return initial_preds, initial_preds

    def predict(self, image_path, threshold=0.325, category_thresholds=None):
        """
        Run inference on an image with support for category-specific thresholds.
        """
        # Preprocess the image
        img_tensor = self.preprocess_image(image_path).unsqueeze(0)

        # Move to the same device as model and convert to half precision
        device = next(self.parameters()).device
        dtype = next(self.parameters()).dtype  # Match model's precision
        img_tensor = img_tensor.to(device, dtype=dtype)

        # Run inference
        with torch.no_grad():
            initial_preds, _ = self.forward(img_tensor)

        # Apply sigmoid to get probabilities
        initial_probs = torch.sigmoid(initial_preds)

        # Apply thresholds
        if category_thresholds:
            # Create binary prediction tensors
            initial_binary = torch.zeros_like(initial_probs)

            # Apply thresholds by category
            for category, cat_threshold in category_thresholds.items():
                # Create a mask for tags in this category
                category_mask = torch.zeros_like(initial_probs, dtype=torch.bool)

                # Find indices for this category
                for tag_idx in range(initial_probs.size(-1)):
                    try:
                        _, tag_category = self.dataset.get_tag_info(tag_idx)
                        if tag_category == category:
                            category_mask[:, tag_idx] = True
                    except:
                        continue

                # Apply threshold only to tags in this category
                cat_threshold_tensor = torch.tensor(cat_threshold, device=device, dtype=dtype)
                initial_binary[category_mask] = (initial_probs[category_mask] >= cat_threshold_tensor).to(dtype)

            predictions = initial_binary
        else:
            # Use the same threshold for all tags
            threshold_tensor = torch.tensor(threshold, device=device, dtype=dtype)
            predictions = (initial_probs >= threshold_tensor).to(dtype)

        # Return the same probabilities for both initial and refined for API compatibility
        return {
            'initial_probabilities': initial_probs,
            'refined_probabilities': initial_probs,  # Same as initial for compatibility
            'predictions': predictions
        }

    def get_tags_from_predictions(self, predictions, include_probabilities=True):
        """
        Convert model predictions to human-readable tags grouped by category.
        """
        # Get non-zero predictions
        if predictions.dim() > 1:
            predictions = predictions[0]  # Remove batch dimension

        # Get indices of positive predictions
        indices = torch.where(predictions > 0)[0].cpu().tolist()

        # Group by category
        result = {}
        for idx in indices:
            tag_name, category = self.dataset.get_tag_info(idx)

            if category not in result:
                result[category] = []

            if include_probabilities:
                prob = predictions[idx].item()
                result[category].append((tag_name, prob))
            else:
                result[category].append(tag_name)

        # Sort tags by probability within each category
        if include_probabilities:
            for category in result:
                result[category] = sorted(result[category], key=lambda x: x[1], reverse=True)

        return result

class FlashAttention(nn.Module):
    def __init__(self, dim, num_heads=8, dropout=0.1, batch_first=True):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.batch_first = batch_first
        self.head_dim = dim // num_heads
        assert self.head_dim * num_heads == dim, "dim must be divisible by num_heads"

        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)

        for proj in [self.q_proj, self.k_proj, self.v_proj, self.out_proj]:
            nn.init.xavier_uniform_(proj.weight, gain=0.1)

        self.scale = self.head_dim ** -0.5
        self.debug = False

    def _debug_print(self, name, tensor):
        """Debug helper"""
        if self.debug:
            print(f"\n{name}:")
            print(f"Shape: {tensor.shape}")
            print(f"Device: {tensor.device}")
            print(f"Dtype: {tensor.dtype}")
            if tensor.is_floating_point():
                with torch.no_grad():
                    print(f"Range: [{tensor.min().item():.3f}, {tensor.max().item():.3f}]")
                    print(f"Mean: {tensor.mean().item():.3f}")
                    print(f"Std: {tensor.std().item():.3f}")

    def _reshape_for_flash(self, x: torch.Tensor) -> torch.Tensor:
        """Reshape input tensor for flash attention format"""
        batch_size, seq_len, _ = x.size()
        x = x.view(batch_size, seq_len, self.num_heads, self.head_dim)
        x = x.transpose(1, 2)  # [B, H, S, D]
        return x.contiguous()

    def forward(self, query: torch.Tensor, key: Optional[torch.Tensor] = None,
                value: Optional[torch.Tensor] = None,
                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Forward pass with flash attention"""
        if self.debug:
            print("\nFlashAttention Forward Pass")

        batch_size = query.size(0)

        # Use query as key/value if not provided
        key = query if key is None else key
        value = query if value is None else value

        # Project inputs
        q = self.q_proj(query)
        k = self.k_proj(key)
        v = self.v_proj(value)

        if self.debug:
            self._debug_print("Query before reshape", q)

        # Reshape for attention [B, H, S, D]
        q = self._reshape_for_flash(q)
        k = self._reshape_for_flash(k)
        v = self._reshape_for_flash(v)

        if self.debug:
            self._debug_print("Query after reshape", q)

        # Handle masking
        if mask is not None:
            # First convert mask to proper shape based on input dimensionality
            if mask.dim() == 2:  # [B, S]
                mask = mask.view(batch_size, 1, -1, 1)
            elif mask.dim() == 3:  # [B, S, S]
                mask = mask.view(batch_size, 1, mask.size(1), mask.size(2))
            elif mask.dim() == 5:  # [B, 1, S, S, S]
                mask = mask.squeeze(1).view(batch_size, 1, mask.size(2), mask.size(3))

            # Ensure mask is float16 if we're using float16
            mask = mask.to(q.dtype)

            if self.debug:
                self._debug_print("Prepared mask", mask)
                print(f"q shape: {q.shape}, mask shape: {mask.shape}")

            # Create attention mask that covers the full sequence length
            seq_len = q.size(2)
            if mask.size(-1) != seq_len:
                # Pad or trim mask to match sequence length
                new_mask = torch.zeros(batch_size, 1, seq_len, seq_len,
                                       device=mask.device, dtype=mask.dtype)
                min_len = min(seq_len, mask.size(-1))
                new_mask[..., :min_len, :min_len] = mask[..., :min_len, :min_len]
                mask = new_mask

            # Create key padding mask
            key_padding_mask = mask.squeeze(1).sum(-1) > 0
            key_padding_mask = key_padding_mask.view(batch_size, 1, -1, 1)

            # Apply the key padding mask
            k = k * key_padding_mask
            v = v * key_padding_mask

        if self.debug:
            self._debug_print("Query before attention", q)
            self._debug_print("Key before attention", k)
            self._debug_print("Value before attention", v)

        # Run flash attention
        # NOTE: flash_attn_func is provided by the flash-attn package
        # (e.g. `from flash_attn import flash_attn_func`); the import is not shown in this excerpt.
        dropout_p = self.dropout if self.training else 0.0
        output = flash_attn_func(
            q, k, v,
            dropout_p=dropout_p,
            softmax_scale=self.scale,
            causal=False
        )

        if self.debug:
            self._debug_print("Output after attention", output)

        # Reshape output [B, H, S, D] -> [B, S, H, D] -> [B, S, D]
        output = output.transpose(1, 2).contiguous()
        output = output.view(batch_size, -1, self.dim)

        # Final projection
        output = self.out_proj(output)

        if self.debug:
            self._debug_print("Final output", output)

        return output

class OptimizedTagEmbedding(nn.Module):
    def __init__(self, num_tags, embedding_dim, num_heads=8, dropout=0.1):
        super().__init__()
        # Single shared embedding for all tags
        self.embedding = nn.Embedding(num_tags, embedding_dim)
        self.attention = FlashAttention(embedding_dim, num_heads, dropout)
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.norm2 = nn.LayerNorm(embedding_dim)

        # Single importance weighting for all tags
        self.tag_importance = nn.Parameter(torch.ones(num_tags) * 0.1)

        # Projection layers for unified tag context
        self.context_proj = nn.Sequential(
            nn.Linear(embedding_dim, embedding_dim * 2),
            nn.LayerNorm(embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(embedding_dim * 2, embedding_dim),
            nn.LayerNorm(embedding_dim)
        )

        self.importance_scale = nn.Parameter(torch.tensor(0.1))
        self.context_scale = nn.Parameter(torch.tensor(1.0))
        self.debug = False

    def _debug_print(self, name, tensor, extra_info=None):
        """Memory efficient debug printing with type handling"""
        if self.debug:
            print(f"\n{name}:")
            print(f"- Shape: {tensor.shape}")
            if isinstance(tensor, torch.Tensor):
                with torch.no_grad():
                    print(f"- Device: {tensor.device}")
                    print(f"- Dtype: {tensor.dtype}")

                    # Convert to float32 for statistics if needed
                    if tensor.dtype not in [torch.float16, torch.float32, torch.float64]:
                        calc_tensor = tensor.float()
                    else:
                        calc_tensor = tensor

                    try:
                        min_val = calc_tensor.min().item()
                        max_val = calc_tensor.max().item()
                        mean_val = calc_tensor.mean().item()
                        std_val = calc_tensor.std().item()
                        norm_val = torch.norm(calc_tensor).item()

                        print(f"- Value range: [{min_val:.3f}, {max_val:.3f}]")
                        print(f"- Mean: {mean_val:.3f}")
                        print(f"- Std: {std_val:.3f}")
                        print(f"- L2 Norm: {norm_val:.3f}")

                        if extra_info:
                            print(f"- Additional info: {extra_info}")
                    except Exception as e:
                        print(f"- Could not compute statistics: {str(e)}")

    def _debug_tensor(self, name, tensor):
        """Debug helper with dtype-specific analysis"""
        if self.debug and isinstance(tensor, torch.Tensor):
            print(f"\n{name}:")
            print(f"- Shape: {tensor.shape}")
            print(f"- Device: {tensor.device}")
            print(f"- Dtype: {tensor.dtype}")
            with torch.no_grad():
                has_nan = torch.isnan(tensor).any().item() if tensor.is_floating_point() else False
                has_inf = torch.isinf(tensor).any().item() if tensor.is_floating_point() else False
                print(f"- Contains NaN: {has_nan}")
                print(f"- Contains Inf: {has_inf}")

                # Different stats for different dtypes
                if tensor.is_floating_point():
                    print(f"- Range: [{tensor.min().item():.3f}, {tensor.max().item():.3f}]")
                    print(f"- Mean: {tensor.mean().item():.3f}")
                    print(f"- Std: {tensor.std().item():.3f}")
                else:
                    # For integer tensors
                    print(f"- Range: [{tensor.min().item()}, {tensor.max().item()}]")
                    print(f"- Unique values: {tensor.unique().numel()}")

    def _process_category(self, indices, masks):
        """Process a single category of tags"""
        # Get embeddings for this category
        embeddings = self.embedding(indices)

        if self.debug:
            self._debug_tensor("Category embeddings", embeddings)

        # Apply importance weights
        importance = torch.sigmoid(self.tag_importance) * self.importance_scale
        importance = torch.clamp(importance, min=0.01, max=10.0)
        importance_weights = importance[indices].unsqueeze(-1)

        # Apply and normalize
        embeddings = embeddings * importance_weights
        embeddings = self.norm1(embeddings)

        # Apply attention if we have more than one tag
        if embeddings.size(1) > 1:
            if masks is not None:
                attention_mask = torch.einsum('bi,bj->bij', masks, masks)
                attended = self.attention(embeddings, mask=attention_mask)
            else:
                attended = self.attention(embeddings)
            embeddings = self.norm2(attended)

        # Pool embeddings with masking
        if masks is not None:
            masked_embeddings = embeddings * masks.unsqueeze(-1)
            pooled = masked_embeddings.sum(dim=1) / masks.sum(dim=1, keepdim=True).clamp(min=1.0)
        else:
            pooled = embeddings.mean(dim=1)

        return pooled, embeddings

    def forward(self, tag_indices_dict, tag_masks_dict=None):
        """
        Process all tags in a unified embedding space
        Args:
            tag_indices_dict: dict of {category: tensor of indices}
            tag_masks_dict: dict of {category: tensor of masks}
        """
        if self.debug:
            print("\nOptimizedTagEmbedding Forward Pass")

        # Concatenate all indices and masks
        all_indices = []
        all_masks = []
        batch_size = None

        for category, indices in tag_indices_dict.items():
            if batch_size is None:
                batch_size = indices.size(0)
            all_indices.append(indices)
            if tag_masks_dict:
                all_masks.append(tag_masks_dict[category])

        # Stack along sequence dimension
        combined_indices = torch.cat(all_indices, dim=1)  # [B, total_seq_len]
        if tag_masks_dict:
            combined_masks = torch.cat(all_masks, dim=1)  # [B, total_seq_len]

        if self.debug:
            self._debug_tensor("Combined indices", combined_indices)
            if tag_masks_dict:
                self._debug_tensor("Combined masks", combined_masks)

        # Get embeddings for all tags using shared embedding
        embeddings = self.embedding(combined_indices)  # [B, total_seq_len, D]

        if self.debug:
            self._debug_tensor("Base embeddings", embeddings)

        # Apply unified importance weighting
        importance = torch.sigmoid(self.tag_importance) * self.importance_scale
        importance = torch.clamp(importance, min=0.01, max=10.0)
        importance_weights = importance[combined_indices].unsqueeze(-1)

        # Apply and normalize importance weights
        embeddings = embeddings * importance_weights
        embeddings = self.norm1(embeddings)

        if self.debug:
            self._debug_tensor("Weighted embeddings", embeddings)

        # Apply attention across all tags together
        if tag_masks_dict:
            attention_mask = torch.einsum('bi,bj->bij', combined_masks, combined_masks)
            attended = self.attention(embeddings, mask=attention_mask)
        else:
            attended = self.attention(embeddings)

        attended = self.norm2(attended)

        if self.debug:
            self._debug_tensor("Attended embeddings", attended)

        # Global pooling with masking
        if tag_masks_dict:
            masked_embeddings = attended * combined_masks.unsqueeze(-1)
            tag_context = masked_embeddings.sum(dim=1) / combined_masks.sum(dim=1, keepdim=True).clamp(min=1.0)
        else:
            tag_context = attended.mean(dim=1)

        # Project and scale context
        tag_context = self.context_proj(tag_context)
        context_scale = torch.clamp(self.context_scale, min=0.1, max=10.0)
        tag_context = tag_context * context_scale

        if self.debug:
            self._debug_tensor("Final tag context", tag_context)

        return tag_context, attended

class TagDataset:
    """Lightweight dataset wrapper for inference only"""
    def __init__(self, total_tags, idx_to_tag, tag_to_category):
        self.total_tags = total_tags
        self.idx_to_tag = idx_to_tag if isinstance(idx_to_tag, dict) else {int(k): v for k, v in idx_to_tag.items()}
        self.tag_to_category = tag_to_category

    def get_tag_info(self, idx):
        """Get tag name and category for a given index"""
        tag_name = self.idx_to_tag.get(idx, f"unknown-{idx}")
        category = self.tag_to_category.get(tag_name, "general")
        return tag_name, category

class ImageTagger(nn.Module):
    def __init__(self, total_tags, dataset, model_name='efficientnet_v2_l',
                 num_heads=16, dropout=0.1, pretrained=True,
                 tag_context_size=256):
        super().__init__()
        # Debug and stats flags
        self._flags = {
            'debug': False,
            'model_stats': False
        }

        # Core model config
        self.dataset = dataset
        self.tag_context_size = tag_context_size
        self.embedding_dim = 1280  # Fixed to EfficientNetV2-L output dimension

        # Initialize backbone
        if model_name == 'efficientnet_v2_l':
            weights = EfficientNet_V2_L_Weights.DEFAULT if pretrained else None
            self.backbone = efficientnet_v2_l(weights=weights)
            self.backbone.classifier = nn.Identity()

        # Spatial pooling only - no projection
        self.spatial_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Initial tag prediction with bottleneck
        self.initial_classifier = nn.Sequential(
            nn.Linear(self.embedding_dim, self.embedding_dim * 2),
            nn.LayerNorm(self.embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(self.embedding_dim * 2, self.embedding_dim),
            nn.LayerNorm(self.embedding_dim),
            nn.GELU(),
            nn.Linear(self.embedding_dim, total_tags)
        )

        # Tag embeddings at full dimension
        self.tag_embedding = nn.Embedding(total_tags, self.embedding_dim)
        self.tag_attention = FlashAttention(self.embedding_dim, num_heads, dropout)
        self.tag_norm = nn.LayerNorm(self.embedding_dim)

        # Improved cross attention projection
        self.cross_proj = nn.Sequential(
            nn.Linear(self.embedding_dim, self.embedding_dim * 2),
            nn.LayerNorm(self.embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(self.embedding_dim * 2, self.embedding_dim)
        )

        # Cross attention at full dimension
        self.cross_attention = FlashAttention(self.embedding_dim, num_heads, dropout)
        self.cross_norm = nn.LayerNorm(self.embedding_dim)

        # Refined classifier with improved bottleneck
        self.refined_classifier = nn.Sequential(
            nn.Linear(self.embedding_dim * 2, self.embedding_dim * 2),  # Doubled input size for residual
            nn.LayerNorm(self.embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(self.embedding_dim * 2, self.embedding_dim),
            nn.LayerNorm(self.embedding_dim),
            nn.GELU(),
            nn.Linear(self.embedding_dim, total_tags)
        )

        # Temperature scaling
        self.temperature = nn.Parameter(torch.ones(1) * 1.5)

    def _get_selected_tags(self, logits):
        """Select top-K tags based on prediction confidence"""
        # Apply sigmoid to get probabilities
        probs = torch.sigmoid(logits)

        # Get top-K predictions for each image in batch
        batch_size = logits.size(0)
        topk_values, topk_indices = torch.topk(
            probs, k=self.tag_context_size, dim=1, largest=True, sorted=True
        )

        return topk_indices, topk_values

    @property
    def debug(self):
        return self._flags['debug']

    @debug.setter
    def debug(self, value):
        self._flags['debug'] = value

    @property
    def model_stats(self):
        return self._flags['model_stats']

    @model_stats.setter
    def model_stats(self, value):
        self._flags['model_stats'] = value

    def preprocess_image(self, image_path, image_size=512):
        """Process an image for inference using same preprocessing as training"""
        if not os.path.exists(image_path):
            raise ValueError(f"Image not found at path: {image_path}")

        # Initialize the same transform used during training
        transform = transforms.Compose([
            transforms.ToTensor(),
        ])

        try:
            with Image.open(image_path) as img:
                # Convert RGBA or Palette images to RGB
                if img.mode in ('RGBA', 'P'):
                    img = img.convert('RGB')

                # Get original dimensions
                width, height = img.size
                aspect_ratio = width / height

                # Calculate new dimensions to maintain aspect ratio
                if aspect_ratio > 1:
                    new_width = image_size
                    new_height = int(new_width / aspect_ratio)
                else:
                    new_height = image_size
                    new_width = int(new_height * aspect_ratio)

                # Resize with LANCZOS filter
                img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

                # Create new image with padding
                new_image = Image.new('RGB', (image_size, image_size), (0, 0, 0))
                paste_x = (image_size - new_width) // 2
                paste_y = (image_size - new_height) // 2
                new_image.paste(img, (paste_x, paste_y))

                # Apply transforms (without normalization)
                img_tensor = transform(new_image)
                return img_tensor
        except Exception as e:
            raise Exception(f"Error processing {image_path}: {str(e)}")

    def forward(self, x):
        """Forward pass with simplified feature handling"""
        # Initialize tracking dicts
        model_stats = {} if self.model_stats else {}
        debug_tensors = {} if self.debug else None

        # 1. Image Feature Extraction
        features = self.backbone.features(x)
        features = self.spatial_pool(features).squeeze(-1).squeeze(-1)

        # 2. Initial Tag Predictions
        initial_logits = self.initial_classifier(features)
        initial_preds = torch.clamp(initial_logits / self.temperature, min=-15.0, max=15.0)

        # 3. Tag Selection & Embedding (simplified)
        pred_tag_indices, _ = self._get_selected_tags(initial_preds)
        tag_embeddings = self.tag_embedding(pred_tag_indices)

        # 4. Self-Attention on Tags
        attended_tags = self.tag_attention(tag_embeddings)
        attended_tags = self.tag_norm(attended_tags)

        # 5. Cross-Attention between Features and Tags
        features_proj = self.cross_proj(features)
        features_expanded = features_proj.unsqueeze(1).expand(-1, self.tag_context_size, -1)

        cross_attended = self.cross_attention(features_expanded, attended_tags)
        cross_attended = self.cross_norm(cross_attended)

        # 6. Feature Fusion with Residual Connection
        fused_features = cross_attended.mean(dim=1)  # Average across tag dimension
        # Concatenate original and attended features
        combined_features = torch.cat([features, fused_features], dim=-1)

        # 7. Refined Predictions
        refined_logits = self.refined_classifier(combined_features)
        refined_preds = torch.clamp(refined_logits / self.temperature, min=-15.0, max=15.0)

        # Return both prediction sets
        return initial_preds, refined_preds

    def predict(self, image_path, threshold=0.325, category_thresholds=None):
        """
        Run inference on an image with support for category-specific thresholds.
        """
        # Preprocess the image
        img_tensor = self.preprocess_image(image_path).unsqueeze(0)

        # Move to the same device as model and convert to half precision
        device = next(self.parameters()).device
        dtype = next(self.parameters()).dtype  # Match model's precision
        img_tensor = img_tensor.to(device, dtype=dtype)

        # Run inference
        with torch.no_grad():
            initial_preds, refined_preds = self.forward(img_tensor)

        # Apply sigmoid to get probabilities
        initial_probs = torch.sigmoid(initial_preds)
        refined_probs = torch.sigmoid(refined_preds)

        # Apply thresholds
        if category_thresholds:
            # Create binary prediction tensors
            refined_binary = torch.zeros_like(refined_probs)

            # Apply thresholds by category
            for category, cat_threshold in category_thresholds.items():
                # Create a mask for tags in this category
                category_mask = torch.zeros_like(refined_probs, dtype=torch.bool)

                # Find indices for this category
                for tag_idx in range(refined_probs.size(-1)):
                    try:
                        _, tag_category = self.dataset.get_tag_info(tag_idx)
                        if tag_category == category:
                            category_mask[:, tag_idx] = True
                    except:
                        continue

                # Apply threshold only to tags in this category - ensure dtype consistency
                cat_threshold_tensor = torch.tensor(cat_threshold, device=device, dtype=dtype)
                refined_binary[category_mask] = (refined_probs[category_mask] >= cat_threshold_tensor).to(dtype)

            predictions = refined_binary
        else:
            # Use the same threshold for all tags
            threshold_tensor = torch.tensor(threshold, device=device, dtype=dtype)
            predictions = (refined_probs >= threshold_tensor).to(dtype)

        # Return both probabilities and thresholded predictions
        return {
            'initial_probabilities': initial_probs,
            'refined_probabilities': refined_probs,
            'predictions': predictions
        }

    def get_tags_from_predictions(self, predictions, include_probabilities=True):
        """
        Convert model predictions to human-readable tags grouped by category.
        """
        # Get non-zero predictions
        if predictions.dim() > 1:
            predictions = predictions[0]  # Remove batch dimension

        # Get indices of positive predictions
        indices = torch.where(predictions > 0)[0].cpu().tolist()

        # Group by category
        result = {}
        for idx in indices:
|
| 808 |
+
tag_name, category = self.dataset.get_tag_info(idx)
|
| 809 |
+
|
| 810 |
+
if category not in result:
|
| 811 |
+
result[category] = []
|
| 812 |
+
|
| 813 |
+
if include_probabilities:
|
| 814 |
+
prob = predictions[idx].item()
|
| 815 |
+
result[category].append((tag_name, prob))
|
| 816 |
+
else:
|
| 817 |
+
result[category].append(tag_name)
|
| 818 |
+
|
| 819 |
+
# Sort tags by probability within each category
|
| 820 |
+
if include_probabilities:
|
| 821 |
+
for category in result:
|
| 822 |
+
result[category] = sorted(result[category], key=lambda x: x[1], reverse=True)
|
| 823 |
+
|
| 824 |
+
return result
|
| 825 |
+
|
| 826 |
+
def load_model(model_dir, device='cuda'):
|
| 827 |
+
"""Load model with better error handling and warnings"""
|
| 828 |
+
print(f"Loading model from {model_dir}")
|
| 829 |
+
|
| 830 |
+
try:
|
| 831 |
+
# Load metadata
|
| 832 |
+
metadata_path = os.path.join(model_dir, "metadata.json")
|
| 833 |
+
if not os.path.exists(metadata_path):
|
| 834 |
+
raise FileNotFoundError(f"Metadata file not found at {metadata_path}")
|
| 835 |
+
|
| 836 |
+
with open(metadata_path, 'r') as f:
|
| 837 |
+
metadata = json.load(f)
|
| 838 |
+
|
| 839 |
+
# Load model info
|
| 840 |
+
model_info_path = os.path.join(model_dir, "model_info_initial_only.json")
|
| 841 |
+
if os.path.exists(model_info_path):
|
| 842 |
+
with open(model_info_path, 'r') as f:
|
| 843 |
+
model_info = json.load(f)
|
| 844 |
+
else:
|
| 845 |
+
print("WARNING: Model info file not found, using default settings")
|
| 846 |
+
model_info = {
|
| 847 |
+
"tag_context_size": 256,
|
| 848 |
+
"num_heads": 16,
|
| 849 |
+
"precision": "float16"
|
| 850 |
+
}
|
| 851 |
+
|
| 852 |
+
# Create dataset wrapper
|
| 853 |
+
dataset = TagDataset(
|
| 854 |
+
total_tags=metadata['total_tags'],
|
| 855 |
+
idx_to_tag=metadata['idx_to_tag'],
|
| 856 |
+
tag_to_category=metadata['tag_to_category']
|
| 857 |
+
)
|
| 858 |
+
|
| 859 |
+
# Initialize model with exact settings from model_info
|
| 860 |
+
model = ImageTagger(
|
| 861 |
+
total_tags=metadata['total_tags'],
|
| 862 |
+
dataset=dataset,
|
| 863 |
+
num_heads=model_info.get('num_heads', 16),
|
| 864 |
+
tag_context_size=model_info.get('tag_context_size', 256),
|
| 865 |
+
pretrained=False
|
| 866 |
+
)
|
| 867 |
+
|
| 868 |
+
# Load weights
|
| 869 |
+
state_dict_path = os.path.join(model_dir, "model.pt")
|
| 870 |
+
if not os.path.exists(state_dict_path):
|
| 871 |
+
raise FileNotFoundError(f"Model state dict not found at {state_dict_path}")
|
| 872 |
+
|
| 873 |
+
state_dict = torch.load(state_dict_path, map_location=device)
|
| 874 |
+
|
| 875 |
+
# First try strict loading
|
| 876 |
+
try:
|
| 877 |
+
model.load_state_dict(state_dict, strict=True)
|
| 878 |
+
print("✓ Model state dict loaded with strict=True successfully")
|
| 879 |
+
except Exception as e:
|
| 880 |
+
print(f"! Strict loading failed: {str(e)}")
|
| 881 |
+
print("Attempting non-strict loading...")
|
| 882 |
+
|
| 883 |
+
# Try non-strict loading
|
| 884 |
+
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
|
| 885 |
+
|
| 886 |
+
print(f"Non-strict loading completed with:")
|
| 887 |
+
print(f"- {len(missing_keys)} missing keys")
|
| 888 |
+
print(f"- {len(unexpected_keys)} unexpected keys")
|
| 889 |
+
|
| 890 |
+
if len(missing_keys) > 0:
|
| 891 |
+
print(f"Sample missing keys: {missing_keys[:5]}")
|
| 892 |
+
if len(unexpected_keys) > 0:
|
| 893 |
+
print(f"Sample unexpected keys: {unexpected_keys[:5]}")
|
| 894 |
+
|
| 895 |
+
# Move model to device
|
| 896 |
+
model = model.to(device)
|
| 897 |
+
|
| 898 |
+
# Set to half precision if needed
|
| 899 |
+
if model_info.get('precision') == 'float16':
|
| 900 |
+
model = model.half()
|
| 901 |
+
print("✓ Model converted to half precision")
|
| 902 |
+
|
| 903 |
+
# Set to eval mode
|
| 904 |
+
model.eval()
|
| 905 |
+
print("✓ Model set to evaluation mode")
|
| 906 |
+
|
| 907 |
+
# Verify parameter dtype
|
| 908 |
+
param_dtype = next(model.parameters()).dtype
|
| 909 |
+
print(f"✓ Model loaded with precision: {param_dtype}")
|
| 910 |
+
|
| 911 |
+
return model, dataset
|
| 912 |
+
|
| 913 |
+
except Exception as e:
|
| 914 |
+
print(f"ERROR loading model: {str(e)}")
|
| 915 |
+
import traceback
|
| 916 |
+
traceback.print_exc()
|
| 917 |
+
raise
|
| 918 |
+
|
| 919 |
+
# Example usage
|
| 920 |
+
if __name__ == "__main__":
|
| 921 |
+
import sys
|
| 922 |
+
|
| 923 |
+
# Get model directory from command line or use default
|
| 924 |
+
model_dir = sys.argv[1] if len(sys.argv) > 1 else "./exported_model"
|
| 925 |
+
|
| 926 |
+
# Load model
|
| 927 |
+
model, dataset, thresholds = load_model(model_dir)
|
| 928 |
+
|
| 929 |
+
# Display info
|
| 930 |
+
print(f"\nModel information:")
|
| 931 |
+
print(f" Total tags: {dataset.total_tags}")
|
| 932 |
+
print(f" Device: {next(model.parameters()).device}")
|
| 933 |
+
print(f" Precision: {next(model.parameters()).dtype}")
|
| 934 |
+
|
| 935 |
+
# Test on an image if provided
|
| 936 |
+
if len(sys.argv) > 2:
|
| 937 |
+
image_path = sys.argv[2]
|
| 938 |
+
print(f"\nRunning inference on {image_path}")
|
| 939 |
+
|
| 940 |
+
# Use category thresholds if available
|
| 941 |
+
if thresholds and 'categories' in thresholds:
|
| 942 |
+
category_thresholds = {cat: opt['balanced']['threshold']
|
| 943 |
+
for cat, opt in thresholds['categories'].items()}
|
| 944 |
+
results = model.predict(image_path, category_thresholds=category_thresholds)
|
| 945 |
+
else:
|
| 946 |
+
results = model.predict(image_path)
|
| 947 |
+
|
| 948 |
+
# Get tags
|
| 949 |
+
tags = model.get_tags_from_predictions(results['predictions'])
|
| 950 |
+
|
| 951 |
+
# Print tags by category
|
| 952 |
+
print("\nPredicted tags:")
|
| 953 |
+
for category, category_tags in tags.items():
|
| 954 |
+
print(f"\n{category.capitalize()}:")
|
| 955 |
+
for tag, prob in category_tags:
|
| 956 |
+
print(f" {tag}: {prob:.3f}")
|
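For downstream use (e.g. building a prompt string), the category-grouped output of get_tags_from_predictions can be flattened into a single comma-separated list. A minimal sketch; the helper name flatten_tags is illustrative and not part of model_code.py:

# Sketch: flatten the category-grouped tags into one comma-separated string.
# `flatten_tags` is an illustrative helper, not part of model_code.py.
def flatten_tags(tags_by_category):
    names = []
    for category, category_tags in tags_by_category.items():
        for entry in category_tags:
            # Entries are (tag, prob) tuples when include_probabilities=True, else plain strings
            names.append(entry[0] if isinstance(entry, tuple) else entry)
    return ", ".join(names)

# Example: tag_string = flatten_tags(model.get_tags_from_predictions(results['predictions']))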
model_config.json
ADDED
@@ -0,0 +1,9 @@
{
  "class_name": "ImageTagger",
  "args": {
    "total_tags": 70527,
    "num_heads": 16,
    "dropout": 0.1,
    "tag_context_size": 256
  }
}
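A minimal sketch of rebuilding the architecture from this config before loading weights. It assumes model_code.py is importable, that ImageTagger accepts the keys listed under "args" plus dataset/pretrained as in load_model above, and that a weights file such as model_refined.safetensors sits next to the config; adjust names as needed:

# Sketch: re-create the tagger from model_config.json (assumptions noted above).
import json
from safetensors.torch import load_file
from model_code import ImageTagger

with open("model_config.json", "r") as f:
    config = json.load(f)

model = ImageTagger(dataset=None, pretrained=False, **config["args"])
state_dict = load_file("model_refined.safetensors", device="cpu")  # assumed weights file name
model.load_state_dict(state_dict, strict=False)
model.eval()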
model_info_initial_only.json
ADDED
@@ -0,0 +1,9 @@
{
  "precision": "float16",
  "tag_context_size": 256,
  "num_heads": 16,
  "architecture": "ImageTagger",
  "embedding_dim": 1280,
  "backbone": "efficientnet_v2_l",
  "model_type": "initial_only"
}
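A small illustrative snippet (file and key names taken from the JSON above) showing how the recorded precision can be applied to a model instance loaded elsewhere:

# Sketch: honor the recorded precision when preparing a loaded model for inference.
import json
import torch

with open("model_info_initial_only.json", "r") as f:
    model_info = json.load(f)

target_dtype = torch.float16 if model_info.get("precision") == "float16" else torch.float32
# model = model.to(dtype=target_dtype)  # apply to a model instance loaded elsewhere
print(f"Backbone: {model_info['backbone']}, embedding dim: {model_info['embedding_dim']}, dtype: {target_dtype}")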
model_no_flash.py
ADDED
@@ -0,0 +1,195 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import efficientnet_v2_l, EfficientNet_V2_L_Weights

class MultiheadAttentionNoFlash(nn.Module):
    """Custom multi-head attention module (replaces FlashAttention) using ONNX-friendly ops."""
    def __init__(self, dim, num_heads=8, dropout=0.0):
        super().__init__()
        assert dim % num_heads == 0, "Embedding dim must be divisible by num_heads"
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5  # scaling factor for dot-product attention

        # Define separate projections for query, key, value, and output (no biases to match FlashAttention)
        self.q_proj = nn.Linear(dim, dim, bias=False)
        self.k_proj = nn.Linear(dim, dim, bias=False)
        self.v_proj = nn.Linear(dim, dim, bias=False)
        self.out_proj = nn.Linear(dim, dim, bias=False)
        # (Note: We omit dropout in attention computation for ONNX simplicity; model should be set to eval mode anyway.)

    def forward(self, query, key=None, value=None):
        # Allow usage as self-attention if key/value not provided
        if key is None:
            key = query
        if value is None:
            value = key

        # Linear projections
        Q = self.q_proj(query)  # [B, S_q, dim]
        K = self.k_proj(key)    # [B, S_k, dim]
        V = self.v_proj(value)  # [B, S_v, dim]

        # Reshape into (B, num_heads, S, head_dim) for computing attention per head
        B, S_q, _ = Q.shape
        _, S_k, _ = K.shape
        Q = Q.view(B, S_q, self.num_heads, self.head_dim).transpose(1, 2)  # [B, heads, S_q, head_dim]
        K = K.view(B, S_k, self.num_heads, self.head_dim).transpose(1, 2)  # [B, heads, S_k, head_dim]
        V = V.view(B, S_k, self.num_heads, self.head_dim).transpose(1, 2)  # [B, heads, S_k, head_dim]

        # Scaled dot-product attention: compute attention weights
        attn_weights = torch.matmul(Q, K.transpose(2, 3))  # [B, heads, S_q, S_k]
        attn_weights = attn_weights * self.scale
        attn_probs = F.softmax(attn_weights, dim=-1)  # softmax over S_k (key length)

        # Apply attention weights to values
        attn_output = torch.matmul(attn_probs, V)  # [B, heads, S_q, head_dim]

        # Reshape back to [B, S_q, dim]
        attn_output = attn_output.transpose(1, 2).contiguous().view(B, S_q, self.dim)
        # Output projection
        output = self.out_proj(attn_output)  # [B, S_q, dim]
        return output

class ImageTaggerRefinedONNX(nn.Module):
    """
    Refined CAMIE Image Tagger model without FlashAttention.
    - EfficientNetV2 backbone
    - Initial classifier for preliminary tag logits
    - Multi-head self-attention on top predicted tag embeddings
    - Multi-head cross-attention between image feature and tag embeddings
    - Refined classifier for final tag logits
    """
    def __init__(self, total_tags, tag_context_size=256, num_heads=16, dropout=0.1):
        super().__init__()
        self.tag_context_size = tag_context_size
        self.embedding_dim = 1280  # EfficientNetV2-L feature dimension

        # Backbone feature extractor (EfficientNetV2-L)
        backbone = efficientnet_v2_l(weights=EfficientNet_V2_L_Weights.DEFAULT)
        backbone.classifier = nn.Identity()  # remove final classification head
        self.backbone = backbone

        # Spatial pooling to get a single feature vector per image (1x1 avg pool)
        self.spatial_pool = nn.AdaptiveAvgPool2d((1, 1))

        # Initial classifier (two-layer MLP) to predict tags from image feature
        self.initial_classifier = nn.Sequential(
            nn.Linear(self.embedding_dim, self.embedding_dim * 2),
            nn.LayerNorm(self.embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(self.embedding_dim * 2, self.embedding_dim),
            nn.LayerNorm(self.embedding_dim),
            nn.GELU(),
            nn.Linear(self.embedding_dim, total_tags)  # outputs raw logits for all tags
        )

        # Embedding for tags (each tag gets an embedding vector, used for attention)
        self.tag_embedding = nn.Embedding(total_tags, self.embedding_dim)

        # Self-attention over the selected tag embeddings (replaces FlashAttention)
        self.tag_attention = MultiheadAttentionNoFlash(self.embedding_dim, num_heads=num_heads, dropout=dropout)
        self.tag_norm = nn.LayerNorm(self.embedding_dim)

        # Projection from image feature to query vector for cross-attention
        self.cross_proj = nn.Sequential(
            nn.Linear(self.embedding_dim, self.embedding_dim * 2),
            nn.LayerNorm(self.embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(self.embedding_dim * 2, self.embedding_dim)
        )
        # Cross-attention between image feature (as query) and tag features (as key/value)
        self.cross_attention = MultiheadAttentionNoFlash(self.embedding_dim, num_heads=num_heads, dropout=dropout)
        self.cross_norm = nn.LayerNorm(self.embedding_dim)

        # Refined classifier (takes concatenated original & attended features)
        self.refined_classifier = nn.Sequential(
            nn.Linear(self.embedding_dim * 2, self.embedding_dim * 2),
            nn.LayerNorm(self.embedding_dim * 2),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(self.embedding_dim * 2, self.embedding_dim),
            nn.LayerNorm(self.embedding_dim),
            nn.GELU(),
            nn.Linear(self.embedding_dim, total_tags)
        )

        # Temperature parameter for scaling logits (to calibrate confidence)
        self.temperature = nn.Parameter(torch.ones(1) * 1.5)

    def forward(self, images):
        # 1. Feature extraction
        feats = self.backbone.features(images)  # [B, 1280, H/32, W/32] features
        feats = self.spatial_pool(feats).squeeze(-1).squeeze(-1)  # [B, 1280] global feature vector per image

        # 2. Initial tag prediction
        initial_logits = self.initial_classifier(feats)  # [B, total_tags]
        # Scale by temperature and clamp (to stabilize extreme values, as in original)
        initial_preds = torch.clamp(initial_logits / self.temperature, min=-15.0, max=15.0)

        # 3. Select top-k predicted tags for context (tag_context_size)
        probs = torch.sigmoid(initial_preds)  # convert logits to probabilities
        # Get indices of top `tag_context_size` tags for each sample
        _, topk_indices = torch.topk(probs, k=self.tag_context_size, dim=1)
        # 4. Embed selected tags
        tag_embeds = self.tag_embedding(topk_indices)  # [B, tag_context_size, embedding_dim]

        # 5. Self-attention on tag embeddings (to refine tag representation)
        attn_tags = self.tag_attention(tag_embeds)  # [B, tag_context_size, embedding_dim]
        attn_tags = self.tag_norm(attn_tags)  # layer norm

        # 6. Cross-attention between image feature and attended tags
        # Expand image features to have one per tag position
        feat_q = self.cross_proj(feats)  # [B, embedding_dim]
        # Repeat each image feature vector tag_context_size times to form a sequence
        feat_q = feat_q.unsqueeze(1).expand(-1, self.tag_context_size, -1)  # [B, tag_context_size, embedding_dim]
        # Use image features as queries, tag embeddings as keys and values
        cross_attn = self.cross_attention(feat_q, attn_tags, attn_tags)  # [B, tag_context_size, embedding_dim]
        cross_attn = self.cross_norm(cross_attn)

        # 7. Fuse features: average the cross-attended tag outputs, and combine with original features
        fused_feature = cross_attn.mean(dim=1)  # [B, embedding_dim]
        combined = torch.cat([feats, fused_feature], dim=1)  # [B, embedding_dim*2]

        # 8. Refined tag prediction
        refined_logits = self.refined_classifier(combined)  # [B, total_tags]
        refined_preds = torch.clamp(refined_logits / self.temperature, min=-15.0, max=15.0)

        return initial_preds, refined_preds

# --- Load the pretrained refined model weights ---
total_tags = 70527  # total number of tags in the dataset (Danbooru 2024)
from safetensors.torch import load_file
safetensors_path = 'model_refined.safetensors'
state_dict = load_file(safetensors_path, device='cpu')  # Load the saved weights (should be an OrderedDict)
#state_dict = torch.load("model_refined.pt", map_location="cpu")  # Load the saved weights (should be an OrderedDict)

# Initialize our model and load weights
model = ImageTaggerRefinedONNX(total_tags=total_tags)
model.load_state_dict(state_dict)
model.eval()  # set to evaluation mode (disable dropout)

# (Optional) Cast to float32 if weights were in half precision
# model = model.float()

# --- Export to ONNX ---
dummy_input = torch.randn(1, 3, 512, 512, requires_grad=False)  # dummy batch of 1 image (3x512x512)
output_onnx_file = "camie_refined_no_flash_v15.onnx"
torch.onnx.export(
    model, dummy_input, output_onnx_file,
    export_params=True,        # store trained parameter weights inside the model file
    opset_version=17,          # ONNX opset version (ensure support for needed ops)
    do_constant_folding=True,  # optimize constant expressions
    input_names=["image"],
    output_names=["initial_tags", "refined_tags"],
    dynamic_axes={             # set batch dimension to be dynamic
        "image": {0: "batch"},
        "initial_tags": {0: "batch"},
        "refined_tags": {0: "batch"}
    }
)
print(f"ONNX model exported to {output_onnx_file}")
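As a quick sanity check of the exported graph, it can be loaded with ONNX Runtime and run on a dummy batch. A sketch, assuming onnxruntime is installed and using the input/output names chosen in the export above:

# Sketch: verify the exported file with ONNX Runtime (pip install onnxruntime).
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession("camie_refined_no_flash_v15.onnx", providers=["CPUExecutionProvider"])
dummy = np.random.rand(1, 3, 512, 512).astype(np.float32)
initial_tags, refined_tags = session.run(["initial_tags", "refined_tags"], {"image": dummy})
print(initial_tags.shape, refined_tags.shape)  # expected: (1, 70527) (1, 70527)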
thresholds.json
ADDED
@@ -0,0 +1,170 @@
{
  "overall": {
    "balanced": {
      "threshold": 0.3285714089870453,
      "f1": 0.6128875755303665,
      "precision": 0.6348684210526315,
      "recall": 0.5923778668258164
    },
    "high_precision": {
      "threshold": 0.48367345333099365,
      "f1": 0.5073781135639239,
      "precision": 0.8244772683675426,
      "recall": 0.3664421519311109
    },
    "high_recall": {
      "threshold": 0.20612245798110962,
      "f1": 0.5140483341286104,
      "precision": 0.38317013976064945,
      "recall": 0.7807144684116293
    }
  },
  "weighted": {
    "f1": {
      "threshold": 0.31224489212036133,
      "value": 0.666115043816508
    }
  },
  "categories": {
    "copyright": {
      "balanced": {
        "threshold": 0.3857142925262451,
        "f1": 0.7885196374622356,
        "precision": 0.903114186851211,
        "recall": 0.6997319034852547
      },
      "high_precision": {
        "threshold": 0.5,
        "f1": 0.7524429967426711,
        "precision": 0.9585062240663901,
        "recall": 0.6193029490616622
      },
      "high_recall": {
        "threshold": 0.13265305757522583,
        "f1": 0.5149136577708007,
        "precision": 0.36403995560488345,
        "recall": 0.8793565683646113
      }
    },
    "character": {
      "balanced": {
        "threshold": 0.30408161878585815,
        "f1": 0.769028871391076,
        "precision": 0.8878787878787879,
        "recall": 0.6782407407407407
      },
      "high_precision": {
        "threshold": 0.47551020979881287,
        "f1": 0.7128129602356407,
        "precision": 0.979757085020243,
        "recall": 0.5601851851851852
      },
      "high_recall": {
        "threshold": 0.13265305757522583,
        "f1": 0.5132616487455197,
        "precision": 0.37175493250259606,
        "recall": 0.8287037037037037
      }
    },
    "general": {
      "balanced": {
        "threshold": 0.3285714089870453,
        "f1": 0.6070014256296532,
        "precision": 0.6206003023105161,
        "recall": 0.5939857393820399
      },
      "high_precision": {
        "threshold": 0.47551020979881287,
        "f1": 0.5074963046385584,
        "precision": 0.7958057395143487,
        "recall": 0.3725328097550894
      },
      "high_recall": {
        "threshold": 0.20612245798110962,
        "f1": 0.5094889521485699,
        "precision": 0.3790529978316777,
        "recall": 0.7767903275808619
      }
    },
    "meta": {
      "balanced": {
        "threshold": 0.31224489212036133,
        "f1": 0.5943152454780362,
        "precision": 0.5948275862068966,
        "recall": 0.5938037865748709
      },
      "high_precision": {
        "threshold": 0.41020408272743225,
        "f1": 0.5087924970691676,
        "precision": 0.7977941176470589,
        "recall": 0.37349397590361444
      },
      "high_recall": {
        "threshold": 0.22244898974895477,
        "f1": 0.5037433155080214,
        "precision": 0.365399534522886,
        "recall": 0.810671256454389
      }
    },
    "rating": {
      "balanced": {
        "threshold": 0.34489795565605164,
        "f1": 0.7964912280701754,
        "precision": 0.7229299363057324,
        "recall": 0.88671875
      },
      "high_precision": {
        "threshold": 0.5,
        "f1": 0.6966824644549763,
        "precision": 0.8855421686746988,
        "recall": 0.57421875
      },
      "high_recall": {
        "threshold": 0.10000000149011612,
        "f1": 0.6538952745849297,
        "precision": 0.4857685009487666,
        "recall": 1.0
      }
    },
    "artist": {
      "balanced": {
        "threshold": 0.22244898974895477,
        "f1": 0.5017921146953405,
        "precision": 0.56,
        "recall": 0.45454545454545453
      },
      "high_precision": {
        "threshold": 0.22244898974895477,
        "f1": 0.5017921146953405,
        "precision": 0.56,
        "recall": 0.45454545454545453
      },
      "high_recall": {
        "threshold": 0.22244898974895477,
        "f1": 0.5017921146953405,
        "precision": 0.56,
        "recall": 0.45454545454545453
      }
    },
    "year": {
      "balanced": {
        "threshold": 0.2877551317214966,
        "f1": 0.32867132867132864,
        "precision": 0.2974683544303797,
        "recall": 0.3671875
      },
      "high_precision": {
        "threshold": 0,
        "f1": 0,
        "precision": 0,
        "recall": 0
      },
      "high_recall": {
        "threshold": 0,
        "f1": 0,
        "precision": 0,
        "recall": 0
      }
    }
  }
}
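A short sketch (file name and key layout taken from the JSON above) of turning this file into the category_thresholds dict that predict() expects:

# Sketch: build per-category thresholds from thresholds.json for model.predict().
import json

with open("thresholds.json", "r") as f:
    thresholds = json.load(f)

# Pick the "balanced" operating point; swap in "high_precision" or "high_recall" as needed.
category_thresholds = {cat: opts["balanced"]["threshold"]
                       for cat, opts in thresholds["categories"].items()}
fallback_threshold = thresholds["overall"]["balanced"]["threshold"]  # ~0.329
print(category_thresholds)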