Shio-Koube committed on
Commit
f54535f
·
verified ·
1 Parent(s): d474dcf

Create server.py

Browse files
Files changed (1) hide show
  1. server.py +562 -0
server.py ADDED
@@ -0,0 +1,562 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Image Tagging Server using timm (PyTorch) and FastAPI.
4
+
5
+ This script sets up a web server that provides endpoints for tagging images
6
+ using a pre-trained timm (PyTorch) model. It supports single image processing, batch
7
+ processing, and can download model artifacts from a Hugging Face repository.
8
+ """
9
+
10
+ import argparse
11
+ import logging
12
+ import math
13
+ import os
14
+ import pathlib
15
+ import time
16
+ import types
17
+ import typing
18
+ from contextlib import asynccontextmanager
19
+ from io import BytesIO
20
+ from pathlib import Path
21
+ from typing import Any, Dict, List
22
+
23
+ import huggingface_hub
24
+ import numpy as np
25
+ import pandas as pd
26
+ import timm
27
+ import torch
28
+ import uvicorn
29
+ from fastapi import FastAPI, File, HTTPException, UploadFile
30
+ from PIL import Image
31
+ from pydantic import BaseModel, Field
32
+ from pydantic_settings import BaseSettings
33
+ from timm.data import create_transform, resolve_data_config
34
+ from torch import nn
35
+ from torch.nn import functional as F
36
+
37
+
38
+ # --- Configuration Management ---
39
class Settings(BaseSettings):
    """Application configuration.

    Values can be passed as keyword arguments (this is what ``main()`` does
    with the parsed command line) or, through the ``env_prefix`` declared in
    ``Config`` below, as ``TAGGER_``-prefixed environment variables.
    """

    host: str = Field(default="0.0.0.0", description="Server host.")
    port: int = Field(default=8080, description="Server port.")
    # NOTE: lifespan() forwards this value to Tagger(instances=...); it is not
    # passed to uvicorn.run() in main().
    instances: int = Field(default=1, description="Number of uvicorn workers.")
    # Any non-zero value is turned into triton=True when the Tagger is built.
    triton: int = Field(default=0, description="Enable triton / torch.compile()")
    log_level: str = Field(default="INFO", description="Logging level.")

    # The default is None, so the annotation has to admit None as well:
    # with a plain ``str`` annotation an explicit ``model_repo=None`` (which is
    # what the argparse default produces) is rejected during validation.
    model_repo: str | None = Field(
        default=None, description="HuggingFace repository for model files."
    )
    # NOTE: the description mentions ONNX, but the default is a safetensors
    # file and the Tagger loads its weights through timm.
    model_file: str = Field(
        default="model.safetensors", description="ONNX model filename."
    )
    tags_file: str = Field(
        default="selected_tags.csv", description="CSV file with tag names."
    )
    thresholds_file: str = Field(
        default="thresholds.csv", description="CSV file with category thresholds."
    )
    # Restricted by ``pattern`` to exactly one of the three listed backends.
    backend: str = Field(
        default="cpu",
        description="Inference backend ('cpu', 'cuda', 'tensorrt').",
        pattern="^(cpu|cuda|tensorrt)$",
    )
    token: str | None = Field(default=None, description="HuggingFace Token.")

    class Config:
        env_prefix = "TAGGER_"
69
+
70
+
71
+ # --- Logging Setup ---
72
class CustomFormatter(logging.Formatter):
    """Formatter that adds a colourised, fixed-width ``levelprefix`` field.

    The format string given to the constructor can reference
    ``%(levelprefix)s`` to get the level name padded to eight characters and
    wrapped in the ANSI colour for that level.
    """

    LEVEL_COLORS = {
        logging.DEBUG: "\x1b[38;20m",  # Grey
        logging.INFO: "\x1b[32m",  # Green
        logging.WARNING: "\x1b[33;20m",  # Yellow
        logging.ERROR: "\x1b[31;20m",  # Red
        logging.CRITICAL: "\x1b[31;1m",  # Bold Red
    }
    RESET = "\x1b[0m"

    def format(self, record: logging.LogRecord) -> str:
        """Attach ``levelprefix`` to *record*, then defer to the base formatter."""
        # Levels without an entry get no colour code, only the reset suffix.
        ansi_start = self.LEVEL_COLORS.get(record.levelno, "")
        padded_name = f"{record.levelname:<8}"
        record.levelprefix = "".join((ansi_start, padded_name, self.RESET))
        return super().format(record)
88
+
89
+
90
def setup_logging(log_level: str):
    """Configure the root logger and return it.

    The root logger gets *log_level* and a single stream handler using
    ``CustomFormatter``; any handlers it had before are replaced.
    """
    root = logging.getLogger()
    root.setLevel(log_level)

    stream = logging.StreamHandler()
    stream.setFormatter(CustomFormatter("%(levelprefix)s | %(message)s"))
    root.handlers = [stream]

    # Suppress verbose logs from other libraries by dropping their handlers.
    for noisy_name in ("uvicorn", "uvicorn.access"):
        logging.getLogger(noisy_name).handlers = []
    return root
101
+
102
+
103
def pil_ensure_rgb(image: Image.Image) -> Image.Image:
    """Return *image* as an RGB image, flattening any alpha onto white.

    Images that are neither RGB nor RGBA are converted first: to RGBA when
    ``image.info`` carries a ``"transparency"`` entry, otherwise to RGB.
    """
    if image.mode not in ("RGB", "RGBA"):
        target_mode = "RGBA" if "transparency" in image.info else "RGB"
        image = image.convert(target_mode)

    if image.mode != "RGBA":
        return image

    # Composite onto a white background before dropping the alpha channel.
    background = Image.new("RGBA", image.size, (255, 255, 255))
    background.alpha_composite(image)
    return background.convert("RGB")
115
+
116
+
117
def pil_pad_square(image: Image.Image) -> Image.Image:
    """Centre *image* on a white square whose side is its longer edge."""
    width, height = image.size
    side = max(width, height)
    # Integer division: an odd amount of padding puts the extra pixel on the
    # right / bottom.
    top_left = ((side - width) // 2, (side - height) // 2)
    square = Image.new("RGB", (side, side), (255, 255, 255))
    square.paste(image, top_left)
    return square
123
+
124
+
125
# Module-level logger, configured at import time with DEBUG verbosity.
# main() rebinds this name after re-running setup_logging() with the
# configured log level.
logger = setup_logging("DEBUG")
126
+
127
+
128
+ # --- API Models (Pydantic) ---
129
class Timing(BaseModel):
    """Timing information attached to a tagging response."""

    # Time for the whole request, as measured by the endpoint.
    total_seconds: float
    # Time spent inside the tagger's prediction call only.
    processing_seconds: float
132
+
133
+
134
# Per-image result: category name -> list of tag entries. Tags builds each
# entry as {"name": ..., "confidence": ...}.
TAG_RESPONSE = dict[str, list[dict[str, Any]]]
135
+
136
+
137
class TaggingResponse(BaseModel):
    """Response body carrying the tags of a single image plus timing data."""

    # Category name -> tag entries for the image.
    tags: TAG_RESPONSE
    timing: Timing
140
+
141
+
142
class BatchTaggingResponse(BaseModel):
    """Response body of the batch endpoint."""

    # Number of entries in ``results``.
    batch_size: int
    # One category -> tags mapping per image.
    results: list[TAG_RESPONSE]
    timing: Timing
146
+
147
+
148
class StatusResponse(BaseModel):
    """Response body of the ``/status`` endpoint."""

    status: str
    # Repository the running tagger was created from.
    model_name: str | None
151
+
152
+
153
class TaggerArgs(BaseModel):
    """Per-request tagging options."""

    # When true, per-tag thresholds (if the tags CSV provides them) are used
    # instead of the per-category threshold.
    tags_threshold: bool = False
155
+
156
+
157
+ # --- Core Logic: Tags & Tagger Classes ---
158
+ class Tags:
159
+ """Handles loading and processing of tag data and thresholds."""
160
+
161
+ DEFAULT_CATEGORIES = {
162
+ 0: {"name": "general", "threshold": 0.35},
163
+ 4: {"name": "character", "threshold": 0.85},
164
+ 9: {"name": "rating", "threshold": 0.0},
165
+ }
166
+
167
+ def __init__(self, labels_path: Path, threshold_path: Path | None = None):
168
+ logger.info(f"Loading labels from '{labels_path}'...")
169
+ start_time = time.time()
170
+
171
+ tags_df = pd.read_csv(labels_path)
172
+ self.tag_names = tags_df["name"].tolist()
173
+ self.tag_names_ndarray = np.array(self.tag_names)
174
+ self.categories: Dict[int, Dict[str, Any]] = {}
175
+
176
+ if "best_threshold" in tags_df:
177
+ self.tag_thresholds = np.array(tags_df["best_threshold"].tolist())
178
+ else:
179
+ self.tag_thresholds = None
180
+
181
+ if (
182
+ threshold_path
183
+ and threshold_path.is_file()
184
+ and threshold_path.stat().st_size > 0
185
+ ):
186
+ logger.info(f"Loading thresholds from '{threshold_path}'.")
187
+ for item in pd.read_csv(threshold_path).to_dict("records"):
188
+ if item["category"] not in self.categories:
189
+ self.categories[item["category"]] = {
190
+ "name": item["name"],
191
+ "threshold": item["threshold"],
192
+ }
193
+ else:
194
+ logger.info("No valid threshold file found. Using default categories.")
195
+ self.categories = self.DEFAULT_CATEGORIES
196
+
197
+ for cat_id, cat_info in self.categories.items():
198
+ cat_info["indices"] = list(np.where(tags_df["category"] == cat_id)[0])
199
+
200
+ logger.info(
201
+ f"Loaded {len(self.tag_names)} tags and {len(self.categories)} categories in {time.time() - start_time:.2f}s."
202
+ )
203
+
204
+ def process_predictions(
205
+ self,
206
+ preds: np.ndarray,
207
+ tag_indices: List[int],
208
+ threshold: float,
209
+ tags_threshold: bool = False,
210
+ ) -> List[List[dict[str, Any]]]:
211
+ """Filters and sorts predictions based on a threshold."""
212
+
213
+ tag_names = self.tag_names_ndarray
214
+ # preds = np.asarray(preds)
215
+ tag_scores = preds[:, tag_indices]
216
+ tag_names_sel = tag_names[tag_indices]
217
+
218
+ if tags_threshold and self.tag_thresholds is not None:
219
+ mask = tag_scores > self.tag_thresholds[tag_indices]
220
+ tag_scores = np.where(mask, tag_scores, -np.inf)
221
+ else:
222
+ if threshold is not None:
223
+ mask = tag_scores > threshold
224
+ tag_scores = np.where(mask, tag_scores, -np.inf)
225
+
226
+ sorted_idx = np.argsort(-tag_scores, axis=1)
227
+ sorted_names = tag_names_sel[sorted_idx]
228
+ sorted_scores = np.take_along_axis(tag_scores, sorted_idx, axis=1)
229
+
230
+ return [
231
+ [
232
+ {"name": name, "confidence": float(score)}
233
+ for name, score in zip(names, scores)
234
+ if not math.isinf(float(score))
235
+ ]
236
+ for names, scores in zip(sorted_names, sorted_scores)
237
+ ]
238
+
239
+ def resolve_batch_probs(
240
+ self, probs: np.ndarray, tags_threshold: bool = False
241
+ ) -> list[dict[str, list[dict[str, Any]]]]:
242
+ """Resolves raw probabilities into categorized tag predictions."""
243
+ logger.info(f"Shapery: {probs.shape[0]}")
244
+ results_batched: dict[str, Any] = {
245
+ cat_info["name"]: [] for cat_info in self.categories.values()
246
+ }
247
+ for cat_info in self.categories.values():
248
+ for _, result in enumerate(
249
+ self.process_predictions(
250
+ probs,
251
+ cat_info["indices"],
252
+ cat_info["threshold"],
253
+ tags_threshold=tags_threshold,
254
+ )
255
+ ):
256
+ # {k: [dic[k] for dic in LD] for k in LD[0]}
257
+ results_batched[cat_info["name"]].append(result)
258
+ results_list = [
259
+ dict(zip(results_batched, t)) for t in zip(*results_batched.values())
260
+ ]
261
+ return results_list
262
+
263
+
264
class Tagger:
    """Loads a timm model from the Hugging Face Hub and runs batched inference.

    Attributes:
        tags_data: ``Tags`` instance used to turn probabilities into tags.
        model_repo: Repository id the model was created from.
        device: ``cuda`` when the CUDA backend was requested and is available,
            otherwise ``cpu``.
        swap_colorspace: Whether the channel order is reversed before
            inference (enabled for ``animetimm/`` repositories).
        transform: timm preprocessing transform resolved from the model's
            pretrained config.
    """

    def __init__(
        self,
        model_repo: str,
        tags: Tags,
        backend: str = "cpu",
        instances: int = 1,
        triton: bool = False,
    ):
        """Create the model, load its weights and build the preprocessing.

        Args:
            model_repo: Hub repository id; passed to timm as ``hf-hub:<repo>``.
            tags: Label data used to resolve predictions.
            backend: ``"cuda"`` selects the GPU when one is available; every
                other value runs on the CPU.
            instances: Number of CUDA devices (ids ``0..instances-1``) the
                model is replicated on via ``nn.DataParallel``.
            triton: Compile the model with ``fullgraph=True``.
        """
        self.tags_data = tags
        self.model_repo = model_repo
        self.device = torch.device(
            "cuda" if backend == "cuda" and torch.cuda.is_available() else "cpu"
        )

        logger.info(f"Loading model from HuggingFace repo: {model_repo}...")
        self.model: nn.Module = timm.create_model(
            "hf-hub:" + model_repo, pretrained=False
        )
        self.swap_colorspace = False
        if model_repo.startswith("animetimm/"):
            logger.warning("Detected animetimm model. Enabling color swap.")
            self.swap_colorspace = True

        state_dict = timm.models.load_state_dict_from_hf(model_repo)
        self.model.load_state_dict(state_dict)
        self.model = self.model.eval().to(self.device)
        if triton:
            self.model.compile(
                fullgraph=True,
            )
        # Resolve the transform before wrapping: the DataParallel wrapper does
        # not expose ``pretrained_cfg``.
        self.transform = create_transform(
            **resolve_data_config(self.model.pretrained_cfg, model=self.model)
        )
        if self.device.type == "cuda":
            # Only replicate when the model actually lives on a GPU. Wrapping a
            # CPU-resident model on a CUDA-capable host makes every forward
            # call fail because the parameters are not on device_ids[0].
            gpu_count = min(max(instances, 1), torch.cuda.device_count())
            self.model = nn.DataParallel(
                self.model, device_ids=list(range(gpu_count))
            )

        logger.info("Model loaded and ready.")

    def _create_model(
        self, model_repo: str, backend: str, index: int
    ) -> torch.nn.Module:
        """Create a stand-alone timm model in eval mode with its Hub weights.

        When *backend* is ``"cuda"`` the model is moved to CUDA device *index*
        as float32. Not called from within this class.
        """
        model: torch.nn.Module = timm.create_model(
            "hf-hub:" + model_repo, pretrained=False
        )
        state_dict = timm.models.load_state_dict_from_hf(model_repo)
        model.load_state_dict(state_dict)
        model = model.eval()
        if backend == "cuda":
            model = model.to(torch.device(backend, index), dtype=torch.float32)
        return model

    def preprocess_batch(self, image_batch: np.ndarray) -> torch.Tensor:
        """Converts NHWC float32 [0-1] NumPy images to a PyTorch tensor in NCHW RGB format."""
        pil_images = [
            Image.fromarray((img * 255).astype(np.uint8)) for img in image_batch
        ]
        images = [pil_pad_square(pil_ensure_rgb(im)) for im in pil_images]
        tensors = [self.transform(im) for im in images]
        batch = torch.stack(tensors, dim=0)

        if self.swap_colorspace:
            logger.debug(f"Swapping colour channels for batch of shape {tuple(batch.shape)}")
            # Reverse the channel axis (index 1 of NCHW).
            batch = batch[:, [2, 1, 0], :, :]
        return batch.to(self.device)

    def predict_batch(
        self, image_batch: np.ndarray, tags_threshold=False
    ) -> List[dict[str, list[dict[str, Any]]]]:
        """Run the model on *image_batch* and return one tag dict per image.

        Args:
            image_batch: Batch of images as accepted by ``preprocess_batch``.
            tags_threshold: Forwarded to ``Tags.resolve_batch_probs``.
        """
        batch_tensor = self.preprocess_batch(image_batch)

        on_cuda = self.device.type == "cuda"
        with (
            torch.inference_mode(),
            # bfloat16 autocast is only wanted on the GPU. Naming the real
            # device type and disabling it elsewhere avoids requesting CUDA
            # autocast on a machine that has no CUDA.
            torch.autocast(
                device_type=self.device.type, dtype=torch.bfloat16, enabled=on_cuda
            ),
        ):
            logits = self.model(batch_tensor)
            probs = F.sigmoid(logits).cpu().to(torch.float32).numpy()

        resolved = self.tags_data.resolve_batch_probs(
            probs, tags_threshold=tags_threshold
        )
        return resolved
351
+
352
+
353
+ # --- FastAPI Application Setup ---
354
class AppState:
    """Holds what the application shares between startup and request handlers."""

    def __init__(self, settings: Settings):
        # Starts empty; the lifespan handler stores the Tagger here on startup
        # and clears it again on shutdown.
        self.tagger: Tagger | None = None
        self.settings = settings
360
+
361
+
362
def download_file(repo: str, filename: str, output_path: Path):
    """Fetch *filename* from the Hub repo *repo* unless *output_path* exists.

    Raises:
        FileNotFoundError: When the download (or the final rename) fails for
            any reason; the original exception is chained as the cause.
    """
    if output_path.exists():
        return

    logger.info(f"Downloading '{filename}' from repo '{repo}'...")
    try:
        downloaded = huggingface_hub.hf_hub_download(
            repo,
            filename,
            local_dir=output_path.parent,
            local_dir_use_symlinks=False,
        )
        # Ensure the downloaded file is at the expected path
        if Path(downloaded) != output_path:
            os.rename(downloaded, output_path)
    except Exception as e:
        message = f"Failed to download '{filename}' from '{repo}': {e}"
        raise FileNotFoundError(message) from e
380
+
381
+
382
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Initialize the Tagger on startup and release it on shutdown.

    Startup resolves the storage directory, downloads the tags file (required)
    and the thresholds file (optional) when ``model_repo`` is not a local
    directory, then builds ``Tags`` and ``Tagger`` and stores the tagger on
    ``app.state``.

    Raises:
        RuntimeError: When a required file or the model repository is missing,
            or when building the tagger fails. Returning before ``yield``
            would also abort startup, but only with contextlib's generic
            "generator didn't yield" error.
    """
    settings: Settings = app.state.settings

    model_dir = Path("models")
    model_dir.mkdir(exist_ok=True)

    repo = settings.model_repo
    repo_is_local_dir = bool(repo) and Path(repo).is_dir()
    if repo_is_local_dir:
        model_dir = Path(repo)
    elif repo:
        model_dir = model_dir / Path(repo)
    logger.info(f"Using directory: {model_dir} for storage...")
    tags_path = model_dir / settings.tags_file
    thresholds_path = model_dir / settings.thresholds_file

    if repo and not repo_is_local_dir:
        try:
            download_file(repo, settings.tags_file, tags_path)
        except FileNotFoundError as e:
            logger.critical(f"Could not start server: {e}")
            raise RuntimeError(f"Could not start server: {e}") from e
        # Thresholds file is optional, so don't fail if it's not there
        try:
            download_file(repo, settings.thresholds_file, thresholds_path)
        except FileNotFoundError:
            logger.warning(
                f"Optional thresholds file '{settings.thresholds_file}' not found in repo."
            )

    if not tags_path.is_file():
        message = "Model or tags file not found, and no model repository was specified. Exiting."
        logger.critical(message)
        raise RuntimeError(message)

    if not repo:
        # Tagger concatenates "hf-hub:" with the repository id, so None would
        # fail there with a TypeError that the handler below does not catch.
        message = "No model repository was specified. Exiting."
        logger.critical(message)
        raise RuntimeError(message)

    try:
        logger.info("Initializing tagger...")
        tags = Tags(labels_path=tags_path, threshold_path=thresholds_path)
        app.state.tagger = Tagger(
            repo,
            tags,
            settings.backend,
            instances=settings.instances,
            triton=bool(settings.triton),
        )
        logger.info("Tagger initialized successfully. Server is ready.")
    except (ValueError, RuntimeError) as e:
        logger.critical(f"Failed to initialize tagger: {e}")
        raise RuntimeError(f"Failed to initialize tagger: {e}") from e

    try:
        yield
    finally:
        # --- Cleanup --- (runs even when shutdown is triggered by an error)
        app.state.tagger = None
        logger.info("Server shutting down.")
441
+
442
+
443
def create_app(settings: Settings) -> FastAPI:
    """Build the FastAPI application and attach an ``AppState`` for *settings*."""
    app_options = {
        "title": "Image Tagger API",
        "description": "An API for tagging images using an ONNX model.",
        "version": "1.0.1",  # Incremented version
        "lifespan": lifespan,
    }
    application = FastAPI(**app_options)
    # Replace the default state object so the lifespan handler and the
    # endpoints can reach the settings and, later, the tagger.
    application.state = AppState(settings)
    return application
453
+
454
+
455
+ # --- Dependency for Endpoints ---
456
def get_tagger(app: FastAPI) -> Tagger:
    """Return the tagger stored on ``app.state``.

    Raises:
        HTTPException: 503 when no tagger has been stored yet.
    """
    tagger = app.state.tagger
    if tagger:
        return tagger
    raise HTTPException(
        status_code=503,
        detail="Tagger is not initialized. The server may be starting up or has encountered an error.",
    )
464
+
465
+
466
+ # --- API Endpoints ---
467
def add_endpoints(app: FastAPI):
    """Register the batch tagging (``POST /``) and ``GET /status`` endpoints on *app*."""
    # Local import: only needed to recognise corrupt .npz (zip) uploads.
    import zipfile

    def tagger_dependency() -> Tagger:
        """Return the tagger of *app*, or raise 503 while it is unavailable."""
        return get_tagger(app)

    # NOTE(review): how FastAPI parses the model-typed ``tags_threshold``
    # parameter next to a multipart file upload is not visible from this file;
    # confirm that clients can actually set it.
    @app.post("/", response_model=BatchTaggingResponse, summary="Tag a batch of images")
    async def tag_batch(
        tags_threshold: TaggerArgs = TaggerArgs(),
        file: UploadFile = File(
            ..., description="A .npz file containing a batch of images in NHWC format."
        ),
    ):
        """Tag every image of the first array stored in the uploaded ``.npz``."""
        if not file.filename or not file.filename.endswith(".npz"):
            raise HTTPException(
                status_code=400,
                detail="Only .npz files are supported for batch processing.",
            )

        start_time = time.time()
        tagger = tagger_dependency()

        logger.info(f"Processing batch file: {file.filename}")
        contents = await file.read()
        try:
            with np.load(BytesIO(contents)) as npz:
                batch = npz[npz.files[0]]
        except (ValueError, OSError, EOFError, IndexError, zipfile.BadZipFile) as e:
            # A broken or empty archive is a client error, not a server fault.
            raise HTTPException(
                status_code=400, detail=f"Could not read .npz payload: {e}"
            ) from e

        logger.info(f"Loaded batch of shape: {batch.shape}")
        process_start = time.time()
        try:
            # Pass the boolean flag, not the TaggerArgs object: a model
            # instance is always truthy, which would force per-tag thresholds
            # on for every request.
            results = tagger.predict_batch(
                batch, tags_threshold=tags_threshold.tags_threshold
            )
        except ValueError as e:
            raise HTTPException(status_code=400, detail=str(e)) from e
        processing_time = time.time() - process_start
        logger.info(f"Processed batch in {processing_time:.2f}s")

        return BatchTaggingResponse(
            batch_size=len(results),
            results=results,
            timing=Timing(
                total_seconds=time.time() - start_time,
                processing_seconds=processing_time,
            ),
        )

    @app.get("/status", response_model=StatusResponse, summary="Get server status")
    async def status():
        """Report ``ok`` and the model repository once the tagger is available."""
        tagger = tagger_dependency()
        return StatusResponse(
            status="ok",
            model_name=tagger.model_repo,
        )
516
+
517
+
518
+ def determine_type(field_type: type):
519
+ if type(field_type) is types.UnionType:
520
+ return typing.get_args(field_type)[0]
521
+ return field_type
522
+
523
+
524
+ # --- Main Execution ---
525
def main():
    """Parse the command line, build the settings and app, and run the server."""
    parser = argparse.ArgumentParser(description="Image Tagging Server")
    # Add arguments that correspond to the Settings fields
    for field_name, field in Settings.model_fields.items():
        parser.add_argument(
            f"--{field_name.replace('_', '-')}",
            type=determine_type(field.annotation),  # Basic type handling for argparse
            # SUPPRESS keeps options that were not given out of the namespace.
            # Forwarding every argparse default as an init kwarg would
            # override the TAGGER_* environment variables, because explicit
            # init values take priority over them.
            default=argparse.SUPPRESS,
            help=f"{field.description} (default: {field.default!r})",
        )
    args = parser.parse_args()

    # Create settings from a combination of args, env vars, and defaults
    settings = Settings(**vars(args))

    global logger
    logger = setup_logging(settings.log_level.upper())

    if settings.token:
        logger.info("Using custom token...")
        # Exported through the environment rather than passed explicitly.
        os.environ["HF_TOKEN"] = settings.token

    app = create_app(settings)
    add_endpoints(app)

    uvicorn.run(
        app,
        host=settings.host,
        port=settings.port,
        log_config=None,  # Use our custom logger
    )
559
+
560
+
561
# Start the server only when this file is executed as a script.
if __name__ == "__main__":
    main()