Spaces:
Sleeping
Sleeping
Update server.py
Browse files
server.py
CHANGED
|
@@ -27,6 +27,7 @@ import timm
|
|
| 27 |
import torch
|
| 28 |
import uvicorn
|
| 29 |
from fastapi import FastAPI, File, HTTPException, UploadFile
|
|
|
|
| 30 |
from PIL import Image
|
| 31 |
from pydantic import BaseModel, Field
|
| 32 |
from pydantic_settings import BaseSettings
|
|
@@ -447,6 +448,8 @@ def create_app(settings: Settings) -> FastAPI:
|
|
| 447 |
description="An API for tagging images using an ONNX model.",
|
| 448 |
version="1.0.1", # Incremented version
|
| 449 |
lifespan=lifespan,
|
|
|
|
|
|
|
| 450 |
)
|
| 451 |
app.state = AppState(settings)
|
| 452 |
return app
|
|
@@ -467,7 +470,20 @@ def get_tagger(app: FastAPI) -> Tagger:
|
|
| 467 |
def add_endpoints(app: FastAPI):
|
| 468 |
tagger_dependency = lambda: get_tagger(app)
|
| 469 |
|
| 470 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 471 |
async def tag_batch(
|
| 472 |
tags_threshold: TaggerArgs = TaggerArgs(),
|
| 473 |
file: UploadFile = File(
|
|
@@ -506,6 +522,7 @@ def add_endpoints(app: FastAPI):
|
|
| 506 |
),
|
| 507 |
)
|
| 508 |
|
|
|
|
| 509 |
@app.get("/status", response_model=StatusResponse, summary="Get server status")
|
| 510 |
async def status():
|
| 511 |
tagger = tagger_dependency()
|
|
|
|
| 27 |
import torch
|
| 28 |
import uvicorn
|
| 29 |
from fastapi import FastAPI, File, HTTPException, UploadFile
|
| 30 |
+
from fastapi.responses import RedirectResponse
|
| 31 |
from PIL import Image
|
| 32 |
from pydantic import BaseModel, Field
|
| 33 |
from pydantic_settings import BaseSettings
|
|
|
|
| 448 |
description="An API for tagging images using an ONNX model.",
|
| 449 |
version="1.0.1", # Incremented version
|
| 450 |
lifespan=lifespan,
|
| 451 |
+
docs_url="/docs",
|
| 452 |
+
|
| 453 |
)
|
| 454 |
app.state = AppState(settings)
|
| 455 |
return app
|
|
|
|
| 470 |
def add_endpoints(app: FastAPI):
|
| 471 |
tagger_dependency = lambda: get_tagger(app)
|
| 472 |
|
| 473 |
+
# Root welcome/docs page
|
| 474 |
+
@app.get("/", include_in_schema=False)
|
| 475 |
+
async def root():
|
| 476 |
+
if app.docs_url:
|
| 477 |
+
return RedirectResponse(url=app.docs_url)
|
| 478 |
+
elif app.redoc_url:
|
| 479 |
+
return RedirectResponse(url=app.redoc_url)
|
| 480 |
+
return HTMLResponse(
|
| 481 |
+
content="<h1>Welcome to the Tagger API</h1><p>Use /batch to tag images.</p>",
|
| 482 |
+
status_code=200
|
| 483 |
+
)
|
| 484 |
+
|
| 485 |
+
# Tagging endpoint at /batch
|
| 486 |
+
@app.post("/batch", response_model=BatchTaggingResponse, summary="Tag a batch of images")
|
| 487 |
async def tag_batch(
|
| 488 |
tags_threshold: TaggerArgs = TaggerArgs(),
|
| 489 |
file: UploadFile = File(
|
|
|
|
| 522 |
),
|
| 523 |
)
|
| 524 |
|
| 525 |
+
# Status endpoint
|
| 526 |
@app.get("/status", response_model=StatusResponse, summary="Get server status")
|
| 527 |
async def status():
|
| 528 |
tagger = tagger_dependency()
|