Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3 | |
""" | |
Image Tagging Server using ONNX and FastAPI. | |
This script sets up a web server that provides endpoints for tagging images | |
using a pre-trained ONNX model. It supports single image processing, batch | |
processing, and can download model artifacts from a Hugging Face repository. | |
""" | |
import argparse | |
import logging | |
import math | |
import os | |
import pathlib | |
import time | |
import types | |
import typing | |
from contextlib import asynccontextmanager | |
from io import BytesIO | |
from pathlib import Path | |
from typing import Any, Dict, List | |
import huggingface_hub | |
import numpy as np | |
import pandas as pd | |
import timm | |
import torch | |
import uvicorn | |
from fastapi import FastAPI, File, HTTPException, UploadFile | |
from PIL import Image | |
from pydantic import BaseModel, Field | |
from pydantic_settings import BaseSettings | |
from timm.data import create_transform, resolve_data_config | |
from torch import nn | |
from torch.nn import functional as F | |
# --- Configuration Management --- | |
class Settings(BaseSettings):
    """Application configuration.

    Values come from keyword arguments, then ``TAGGER_``-prefixed environment
    variables (e.g. ``TAGGER_PORT``), then the defaults declared here.
    """

    host: str = Field(default="0.0.0.0", description="Server host.")
    port: int = Field(default=8080, description="Server port.")
    # Forwarded to Tagger(instances=...), where it selects how many devices
    # the model is replicated on.
    instances: int = Field(default=1, description="Number of uvicorn workers.")
    # Any non-zero value enables compilation of the model.
    triton: int = Field(default=0, description="Enable triton / torch.compile()")
    log_level: str = Field(default="INFO", description="Logging level.")
    # Optional: the default is None, so the annotation must allow None or an
    # explicitly passed None fails validation.
    model_repo: str | None = Field(
        default=None, description="HuggingFace repository for model files."
    )
    model_file: str = Field(
        default="model.safetensors", description="ONNX model filename."
    )
    tags_file: str = Field(
        default="selected_tags.csv", description="CSV file with tag names."
    )
    thresholds_file: str = Field(
        default="thresholds.csv", description="CSV file with category thresholds."
    )
    backend: str = Field(
        default="cpu",
        description="Inference backend ('cpu', 'cuda', 'tensorrt').",
        pattern="^(cpu|cuda|tensorrt)$",
    )
    token: str | None = Field(default=None, description="HuggingFace Token.")

    class Config:
        env_prefix = "TAGGER_"
        # Fields named model_* would otherwise collide with pydantic's
        # protected "model_" namespace and emit a warning at import time.
        protected_namespaces = ("settings_",)
# --- Logging Setup --- | |
class CustomFormatter(logging.Formatter):
    """Formatter that exposes a colorized level name as ``%(levelprefix)s``."""

    RESET = "\x1b[0m"
    LEVEL_COLORS = {
        logging.DEBUG: "\x1b[38;20m",  # Grey
        logging.INFO: "\x1b[32m",  # Green
        logging.WARNING: "\x1b[33;20m",  # Yellow
        logging.ERROR: "\x1b[31;20m",  # Red
        logging.CRITICAL: "\x1b[31;1m",  # Bold Red
    }

    def format(self, record: logging.LogRecord) -> str:
        """Attach ``levelprefix`` to *record*, then format it as usual."""
        # Left-align the level name in an 8-character column.
        padded_level = f"{record.levelname:<8}"
        # Levels without a configured color are emitted uncolored.
        ansi_color = self.LEVEL_COLORS.get(record.levelno, "")
        record.levelprefix = ansi_color + padded_level + self.RESET
        return super().format(record)
def setup_logging(log_level: str):
    """Configure the root logger with a single colorized stream handler.

    Returns the root logger after setting its level to *log_level*.
    """
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(CustomFormatter("%(levelprefix)s | %(message)s"))

    root = logging.getLogger()
    root.setLevel(log_level)
    # Replace any previously installed handlers.
    root.handlers = [stream_handler]

    # Drop uvicorn's own handlers from these two loggers.
    for uvicorn_logger_name in ("uvicorn", "uvicorn.access"):
        logging.getLogger(uvicorn_logger_name).handlers = []
    return root
def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Return *image* in RGB mode, flattening any alpha onto white."""
    if image.mode not in ("RGB", "RGBA"):
        # Keep transparency information (if declared) so it can be composited.
        target_mode = "RGBA" if "transparency" in image.info else "RGB"
        image = image.convert(target_mode)
    if image.mode != "RGBA":
        return image
    # Composite the image over a white background, then drop the alpha band.
    background = Image.new("RGBA", image.size, (255, 255, 255))
    background.alpha_composite(image)
    return background.convert("RGB")
def pil_pad_square(image: Image.Image) -> Image.Image:
    """Center *image* on a white square whose side is its longest edge."""
    width, height = image.size
    side = max(width, height)
    # Top-left corner that centers the image on the square canvas.
    offset = ((side - width) // 2, (side - height) // 2)
    square = Image.new("RGB", (side, side), (255, 255, 255))
    square.paste(image, offset)
    return square
# Module-level logger, configured at DEBUG when the module is imported.
# main() rebinds it using the configured log level.
logger = setup_logging("DEBUG")
# --- API Models (Pydantic) --- | |
class Timing(BaseModel):
    """Wall-clock timings, in seconds, reported alongside a response."""

    # Time measured for the whole request handler.
    total_seconds: float
    # Time spent in the tagger's prediction call only.
    processing_seconds: float
# Per-image result: category name -> list of {"name": ..., "confidence": ...}
# entries, as produced by Tags.resolve_batch_probs().
TAG_RESPONSE = dict[str, list[dict[str, Any]]]
class TaggingResponse(BaseModel):
    """Response body carrying the tags for one image plus timing data."""

    tags: TAG_RESPONSE
    timing: Timing
class BatchTaggingResponse(BaseModel):
    """Response body for a batch request: one result per image plus timing."""

    # Number of entries in `results`.
    batch_size: int
    results: list[TAG_RESPONSE]
    timing: Timing
class StatusResponse(BaseModel):
    """Response body for the status handler."""

    # `model_name` falls inside pydantic v2's protected "model_" namespace,
    # which emits a UserWarning when the class is created. The field name is
    # part of the API, so lift the protection instead of renaming it.
    model_config = {"protected_namespaces": ()}

    status: str
    # Repository identifier of the loaded model; required but may be null.
    model_name: str | None
class TaggerArgs(BaseModel):
    """Optional per-request tagging options."""

    # When true, per-tag thresholds (if the labels file provides them) are
    # used instead of the per-category threshold.
    tags_threshold: bool = False
# --- Core Logic: Tags & Tagger Classes --- | |
class Tags: | |
"""Handles loading and processing of tag data and thresholds.""" | |
DEFAULT_CATEGORIES = { | |
0: {"name": "general", "threshold": 0.35}, | |
4: {"name": "character", "threshold": 0.85}, | |
9: {"name": "rating", "threshold": 0.0}, | |
} | |
def __init__(self, labels_path: Path, threshold_path: Path | None = None): | |
logger.info(f"Loading labels from '{labels_path}'...") | |
start_time = time.time() | |
tags_df = pd.read_csv(labels_path) | |
self.tag_names = tags_df["name"].tolist() | |
self.tag_names_ndarray = np.array(self.tag_names) | |
self.categories: Dict[int, Dict[str, Any]] = {} | |
if "best_threshold" in tags_df: | |
self.tag_thresholds = np.array(tags_df["best_threshold"].tolist()) | |
else: | |
self.tag_thresholds = None | |
if ( | |
threshold_path | |
and threshold_path.is_file() | |
and threshold_path.stat().st_size > 0 | |
): | |
logger.info(f"Loading thresholds from '{threshold_path}'.") | |
for item in pd.read_csv(threshold_path).to_dict("records"): | |
if item["category"] not in self.categories: | |
self.categories[item["category"]] = { | |
"name": item["name"], | |
"threshold": item["threshold"], | |
} | |
else: | |
logger.info("No valid threshold file found. Using default categories.") | |
self.categories = self.DEFAULT_CATEGORIES | |
for cat_id, cat_info in self.categories.items(): | |
cat_info["indices"] = list(np.where(tags_df["category"] == cat_id)[0]) | |
logger.info( | |
f"Loaded {len(self.tag_names)} tags and {len(self.categories)} categories in {time.time() - start_time:.2f}s." | |
) | |
def process_predictions( | |
self, | |
preds: np.ndarray, | |
tag_indices: List[int], | |
threshold: float, | |
tags_threshold: bool = False, | |
) -> List[List[dict[str, Any]]]: | |
"""Filters and sorts predictions based on a threshold.""" | |
tag_names = self.tag_names_ndarray | |
# preds = np.asarray(preds) | |
tag_scores = preds[:, tag_indices] | |
tag_names_sel = tag_names[tag_indices] | |
if tags_threshold and self.tag_thresholds is not None: | |
mask = tag_scores > self.tag_thresholds[tag_indices] | |
tag_scores = np.where(mask, tag_scores, -np.inf) | |
else: | |
if threshold is not None: | |
mask = tag_scores > threshold | |
tag_scores = np.where(mask, tag_scores, -np.inf) | |
sorted_idx = np.argsort(-tag_scores, axis=1) | |
sorted_names = tag_names_sel[sorted_idx] | |
sorted_scores = np.take_along_axis(tag_scores, sorted_idx, axis=1) | |
return [ | |
[ | |
{"name": name, "confidence": float(score)} | |
for name, score in zip(names, scores) | |
if not math.isinf(float(score)) | |
] | |
for names, scores in zip(sorted_names, sorted_scores) | |
] | |
def resolve_batch_probs( | |
self, probs: np.ndarray, tags_threshold: bool = False | |
) -> list[dict[str, list[dict[str, Any]]]]: | |
"""Resolves raw probabilities into categorized tag predictions.""" | |
logger.info(f"Shapery: {probs.shape[0]}") | |
results_batched: dict[str, Any] = { | |
cat_info["name"]: [] for cat_info in self.categories.values() | |
} | |
for cat_info in self.categories.values(): | |
for _, result in enumerate( | |
self.process_predictions( | |
probs, | |
cat_info["indices"], | |
cat_info["threshold"], | |
tags_threshold=tags_threshold, | |
) | |
): | |
# {k: [dic[k] for dic in LD] for k in LD[0]} | |
results_batched[cat_info["name"]].append(result) | |
results_list = [ | |
dict(zip(results_batched, t)) for t in zip(*results_batched.values()) | |
] | |
return results_list | |
class Tagger:
    """Owns the timm/PyTorch model and runs image preprocessing and inference."""

    def __init__(
        self,
        model_repo: str,
        tags: Tags,
        backend: str = "cpu",
        instances: int = 1,
        triton: bool = False,
    ):
        """Load the model from *model_repo* and prepare it for inference.

        Args:
            model_repo: Hugging Face repository id passed to timm.
            tags: Tag metadata used to resolve model outputs.
            backend: ``"cuda"`` selects the GPU when one is available; any
                other value runs on CPU.
            instances: Number of CUDA devices to replicate the model on.
            triton: Compile the model when true.
        """
        self.tags_data = tags
        self.model_repo = model_repo
        use_cuda = backend == "cuda" and torch.cuda.is_available()
        if backend != "cpu" and not use_cuda:
            logger.warning(f"Backend '{backend}' is unavailable. Falling back to CPU.")
        self.device = torch.device("cuda" if use_cuda else "cpu")
        logger.info(f"Loading model from HuggingFace repo: {model_repo}...")
        self.model: nn.Module = timm.create_model(
            "hf-hub:" + model_repo, pretrained=False
        )
        self.swap_colorspace = False
        if model_repo.startswith("animetimm/"):
            logger.warning("Detected animetimm model. Enabling color swap.")
            self.swap_colorspace = True
        state_dict = timm.models.load_state_dict_from_hf(model_repo)
        self.model.load_state_dict(state_dict)
        self.model = self.model.eval().to(self.device)
        if triton:
            self.model.compile(
                fullgraph=True,
            )
        # Must be read before any wrapping: the wrapper does not expose
        # `pretrained_cfg`.
        self.transform = create_transform(
            **resolve_data_config(self.model.pretrained_cfg, model=self.model)
        )
        if self.device.type == "cuda":
            # DataParallel requires the module to live on device_ids[0], so it
            # is only applied on CUDA. Clamp to the devices actually present
            # (and to at least one) to avoid invalid device ids.
            device_count = max(1, min(instances, torch.cuda.device_count()))
            self.model = nn.DataParallel(
                self.model, device_ids=list(range(device_count))
            )
        logger.info("Model loaded and ready.")

    def _create_model(
        self, model_repo: str, backend: str, index: int
    ) -> torch.nn.Module:
        """Create a standalone eval-mode model, placed on ``cuda:<index>`` if requested."""
        model: torch.nn.Module = timm.create_model(
            "hf-hub:" + model_repo, pretrained=False
        )
        state_dict = timm.models.load_state_dict_from_hf(model_repo)
        model.load_state_dict(state_dict)
        model = model.eval()
        if backend == "cuda":
            model = model.to(torch.device(backend, index), dtype=torch.float32)
        return model

    def preprocess_batch(self, image_batch: np.ndarray) -> torch.Tensor:
        """Converts NHWC float32 [0-1] NumPy images to a PyTorch tensor in NCHW RGB format."""
        pil_images = [
            # Clip before the cast so slightly out-of-range values saturate
            # instead of wrapping around in uint8.
            Image.fromarray(np.clip(img * 255, 0, 255).astype(np.uint8))
            for img in image_batch
        ]
        images = [pil_pad_square(pil_ensure_rgb(im)) for im in pil_images]
        tensors = [self.transform(im) for im in images]
        batch = torch.stack(tensors, dim=0)
        if self.swap_colorspace:
            logger.debug(f"Swapping channel order for batch of shape {tuple(batch.shape)}")
            # Reverse the channel axis (RGB -> BGR).
            batch = batch[:, [2, 1, 0], :, :]
        return batch.to(self.device)

    def predict_batch(
        self, image_batch: np.ndarray, tags_threshold=False
    ) -> List[dict[str, list[dict[str, Any]]]]:
        """Run the model on *image_batch* and return resolved tags per image."""
        batch_tensor = self.preprocess_batch(image_batch)
        on_cuda = self.device.type == "cuda"
        with (
            torch.inference_mode(),
            # bfloat16 autocast is used on CUDA only; on CPU the context is
            # created disabled rather than requesting an absent CUDA device.
            torch.autocast(
                device_type=self.device.type, dtype=torch.bfloat16, enabled=on_cuda
            ),
        ):
            logits = self.model(batch_tensor)
        # Upcast after the sigmoid: NumPy has no bfloat16 dtype.
        probs = torch.sigmoid(logits).cpu().to(torch.float32).numpy()
        resolved = self.tags_data.resolve_batch_probs(
            probs, tags_threshold=tags_threshold
        )
        return resolved
# --- FastAPI Application Setup --- | |
class AppState:
    """Container for application state, like the tagger instance."""

    def __init__(self, settings: Settings):
        # Configuration the app was created with; read by lifespan() at startup.
        self.settings = settings
        # Assigned by lifespan() once the model is loaded; None before that
        # and again after shutdown.
        self.tagger: Tagger | None = None
def download_file(repo: str, filename: str, output_path: Path):
    """Download *filename* from a Hugging Face repo unless it already exists.

    Args:
        repo: Repository id on the Hugging Face Hub.
        filename: Path of the file inside the repository.
        output_path: Where the file must end up locally.

    Raises:
        FileNotFoundError: If the download or the final move fails for any
            reason. The original error is chained as the cause.
    """
    if output_path.exists():
        return
    logger.info(f"Downloading '{filename}' from repo '{repo}'...")
    try:
        output_path.parent.mkdir(parents=True, exist_ok=True)
        # `local_dir_use_symlinks` is deprecated in huggingface_hub and is
        # therefore no longer passed.
        path = huggingface_hub.hf_hub_download(
            repo,
            filename,
            local_dir=output_path.parent,
        )
        # Ensure the downloaded file is at the expected path. Compare resolved
        # paths so a relative/absolute mismatch is not mistaken for a
        # different file.
        downloaded = Path(path)
        if downloaded.resolve() != output_path.resolve():
            os.replace(downloaded, output_path)
    except Exception as e:
        # Callers rely on FileNotFoundError to distinguish "file unavailable".
        raise FileNotFoundError(
            f"Failed to download '{filename}' from '{repo}': {e}"
        ) from e
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize the Tagger on startup and release it on shutdown.

    Tag and threshold files are read from ``settings.model_repo`` when it is
    a local directory, otherwise from ``models/<model_repo>`` after being
    downloaded from the Hugging Face Hub.

    Raises:
        RuntimeError: If required files are missing or no repository is
            configured, so startup fails with an explicit reason.
        ValueError: Re-raised from tagger initialization.
    """
    settings: Settings = app.state.settings
    repo = settings.model_repo
    model_dir = Path("models")
    model_dir.mkdir(exist_ok=True)
    repo_is_local_dir = bool(repo) and Path(repo).is_dir()
    if repo_is_local_dir:
        model_dir = Path(repo)
    elif repo:
        model_dir = model_dir / repo
    logger.info(f"Using directory: {model_dir} for storage...")
    tags_path = model_dir / settings.tags_file
    thresholds_path = model_dir / settings.thresholds_file

    if repo and not repo_is_local_dir:
        try:
            download_file(repo, settings.tags_file, tags_path)
        except FileNotFoundError as e:
            logger.critical(f"Could not start server: {e}")
            raise RuntimeError(f"Could not start server: {e}") from e
        # Thresholds file is optional, so don't fail if it's not there
        try:
            download_file(repo, settings.thresholds_file, thresholds_path)
        except FileNotFoundError:
            logger.warning(
                f"Optional thresholds file '{settings.thresholds_file}' not found in repo."
            )

    # The model itself is loaded from the repository id, so a missing
    # repository is as fatal as a missing tags file.
    if not repo or not tags_path.is_file():
        message = "Model or tags file not found, and no model repository was specified. Exiting."
        logger.critical(message)
        raise RuntimeError(message)

    try:
        logger.info("Initializing tagger...")
        tags = Tags(labels_path=tags_path, threshold_path=thresholds_path)
        app.state.tagger = Tagger(
            repo,
            tags,
            settings.backend,
            instances=settings.instances,
            triton=bool(settings.triton),
        )
        logger.info("Tagger initialized successfully. Server is ready.")
    except (ValueError, RuntimeError) as e:
        logger.critical(f"Failed to initialize tagger: {e}")
        raise

    try:
        yield
    finally:
        # --- Cleanup --- runs even if the server exits with an error.
        app.state.tagger = None
        logger.info("Server shutting down.")
def create_app(settings: Settings) -> FastAPI:
    """Build the FastAPI application and attach its shared state object."""
    application = FastAPI(
        lifespan=lifespan,
        version="1.0.1",  # Incremented version
        title="Image Tagger API",
        description="An API for tagging images using an ONNX model.",
    )
    # Swap in our own state container so lifespan() and the handlers can reach
    # the settings and the tagger.
    application.state = AppState(settings)
    return application
# --- Dependency for Endpoints --- | |
def get_tagger(app: FastAPI) -> Tagger:
    """Return the app's tagger, or fail with HTTP 503 when it is not set."""
    if app.state.tagger:
        return app.state.tagger
    raise HTTPException(
        detail="Tagger is not initialized. The server may be starting up or has encountered an error.",
        status_code=503,
    )
# --- API Endpoints --- | |
def add_endpoints(app: FastAPI):
    """Define the batch-tagging and status handlers and register them on *app*."""
    # Local import: only needed to recognise corrupt uploads.
    import zipfile

    def tagger_dependency() -> Tagger:
        """Return the initialized tagger or raise HTTP 503."""
        return get_tagger(app)

    async def tag_batch(
        tags_threshold: TaggerArgs = TaggerArgs(),
        file: UploadFile = File(
            ..., description="A .npz file containing a batch of images in NHWC format."
        ),
    ):
        """Tag every image in the first array of an uploaded ``.npz`` file."""
        if not file.filename or not file.filename.endswith(".npz"):
            raise HTTPException(
                status_code=400,
                detail="Only .npz files are supported for batch processing.",
            )
        start_time = time.time()
        tagger = tagger_dependency()
        logger.info(f"Processing batch file: {file.filename}")
        contents = await file.read()
        try:
            # np.load keeps allow_pickle=False by default, so uploaded data
            # cannot trigger unpickling.
            with np.load(BytesIO(contents)) as npz:
                if not npz.files:
                    raise HTTPException(
                        status_code=400,
                        detail="The uploaded .npz file contains no arrays.",
                    )
                batch = npz[npz.files[0]]
        except (ValueError, OSError, zipfile.BadZipFile) as e:
            raise HTTPException(
                status_code=400, detail=f"Could not read .npz file: {e}"
            ) from e
        logger.info(f"Loaded batch of shape: {batch.shape}")
        process_start = time.time()
        try:
            # Pass the boolean flag, not the TaggerArgs object: a model
            # instance is always truthy and would force per-tag thresholds on.
            results = tagger.predict_batch(
                batch, tags_threshold=tags_threshold.tags_threshold
            )
        except ValueError as e:
            raise HTTPException(status_code=400, detail=str(e)) from e
        processing_time = time.time() - process_start
        logger.info(f"Processed batch in {processing_time:.2f}s")
        return BatchTaggingResponse(
            batch_size=len(results),
            results=results,
            timing=Timing(
                total_seconds=time.time() - start_time,
                processing_seconds=processing_time,
            ),
        )

    async def status():
        """Report that the tagger is loaded, together with its model repository."""
        tagger = tagger_dependency()
        return StatusResponse(
            status="ok",
            model_name=tagger.model_repo,
        )

    # NOTE(review): no route registration was visible for these handlers, so
    # they were unreachable. The paths below are assumptions - confirm them
    # against the clients of this API.
    app.add_api_route(
        "/batch", tag_batch, methods=["POST"], response_model=BatchTaggingResponse
    )
    app.add_api_route(
        "/status", status, methods=["GET"], response_model=StatusResponse
    )
def determine_type(field_type: type): | |
if type(field_type) is types.UnionType: | |
return typing.get_args(field_type)[0] | |
return field_type | |
# --- Main Execution --- | |
def main():
    """Parse CLI arguments, build the settings and app, and run the server."""
    parser = argparse.ArgumentParser(description="Image Tagging Server")
    # Add arguments that correspond to the Settings fields
    for field_name, field in Settings.model_fields.items():
        parser.add_argument(
            f"--{field_name.replace('_', '-')}",
            type=determine_type(field.annotation),  # Basic type handling for argparse
            # Omit options the user did not pass. Otherwise every field's
            # default would be sent to Settings explicitly, overriding the
            # TAGGER_* environment variables and passing None for optional
            # fields.
            default=argparse.SUPPRESS,
            help=field.description,
        )
    args = parser.parse_args()
    # Explicit CLI options win; environment variables and defaults fill the rest.
    settings = Settings(**vars(args))
    global logger
    logger = setup_logging(settings.log_level.upper())
    if settings.token:
        logger.info("Using custom token...")
        # Expose the token through the environment for Hugging Face downloads.
        os.environ["HF_TOKEN"] = settings.token
    app = create_app(settings)
    add_endpoints(app)
    uvicorn.run(
        app,
        host=settings.host,
        port=settings.port,
        log_config=None,  # Use our custom logger
    )
if __name__ == "__main__": | |
main() | |