"""Celery tasks for the parallel inference pipeline.

``preprocess`` (queue "pre") prepares images and hands them to the inference
worker via shared memory; ``postprocess`` (queue "post") turns raw model
outputs back into responses and publishes them over Redis.
"""

import json
from dataclasses import asdict
from multiprocessing import shared_memory
from typing import Dict, List, Tuple

import numpy as np
from celery import Celery
from redis import ConnectionPool, Redis

import inference.enterprise.parallel.celeryconfig
from inference.core.entities.requests.inference import (
    InferenceRequest,
    request_from_type,
)
from inference.core.entities.responses.inference import InferenceResponse
from inference.core.env import REDIS_HOST, REDIS_PORT, STUB_CACHE_SIZE
from inference.core.managers.decorators.fixed_size_cache import WithFixedSizeCache
from inference.core.managers.decorators.locked_load import (
    LockedLoadModelManagerDecorator,
)
from inference.core.managers.stub_loader import StubLoaderManager
from inference.core.registries.roboflow import RoboflowModelRegistry
from inference.enterprise.parallel.utils import (
    SUCCESS_STATE,
    SharedMemoryMetadata,
    failure_handler,
    shm_manager,
)
from inference.models.utils import ROBOFLOW_MODEL_TYPES
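
# Worker-side wiring: a shared Redis connection pool, the Celery app (using
# Redis as the broker), and a model manager that stub-loads models, guards
# concurrent loads (LockedLoadModelManagerDecorator), and bounds the cache
# at STUB_CACHE_SIZE entries (WithFixedSizeCache).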
pool = ConnectionPool(host=REDIS_HOST, port=REDIS_PORT, decode_responses=True)
app = Celery("tasks", broker=f"redis://{REDIS_HOST}:{REDIS_PORT}")
app.config_from_object(inference.enterprise.parallel.celeryconfig)
model_registry = RoboflowModelRegistry(ROBOFLOW_MODEL_TYPES)
model_manager = StubLoaderManager(model_registry)
model_manager = WithFixedSizeCache(
    LockedLoadModelManagerDecorator(model_manager), max_size=STUB_CACHE_SIZE
)
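

# Stage one of the pipeline: runs on the "pre" queue, performs model-specific
# preprocessing, and hands the pixel buffer to the inference worker through
# POSIX shared memory, queueing only the segment's metadata in Redis.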
@app.task(queue="pre")
def preprocess(request: Dict):
    redis_client = Redis(connection_pool=pool)
    with failure_handler(redis_client, request["id"]):
        model_manager.add_model(request["model_id"], request["api_key"])
        model_type = model_manager.get_task_type(request["model_id"])
        request = request_from_type(model_type, request)
        image, preprocess_return_metadata = model_manager.preprocess(
            request.model_id, request
        )

        # Keep only the first element of the returned batch and drop the raw
        # image bytes from the request before it is forwarded.
        image = image[0]
        request.image.value = None
        # Copy the preprocessed image into a fresh shared memory segment so
        # the inference worker can attach to it by name.
        shm = shared_memory.SharedMemory(create=True, size=image.nbytes)
        with shm_manager(shm):
            shared = np.ndarray(image.shape, dtype=image.dtype, buffer=shm.buf)
            shared[:] = image[:]
            shm_metadata = SharedMemoryMetadata(shm.name, image.shape, image.dtype.name)
            queue_infer_task(
                redis_client, shm_metadata, request, preprocess_return_metadata
            )
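

# Stage two: runs on the "post" queue once the inference worker has written
# raw model outputs into shared memory. Rebuilds those outputs as numpy
# arrays, applies model-specific postprocessing, and publishes the final
# response.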
@app.task(queue="post")
def postprocess(
    shm_info_list: List[Dict], request: Dict, preproc_return_metadata: Dict
):
    redis_client = Redis(connection_pool=pool)
    shm_info_list: List[SharedMemoryMetadata] = [
        SharedMemoryMetadata(**metadata) for metadata in shm_info_list
    ]
    with failure_handler(redis_client, request["id"]):
        # Unlink the segments once postprocessing succeeds; failure_handler
        # reports any error for this request id back through Redis.
        with shm_manager(
            *[shm_metadata.shm_name for shm_metadata in shm_info_list],
            unlink_on_success=True,
        ) as shms:
            model_manager.add_model(request["model_id"], request["api_key"])
            model_type = model_manager.get_task_type(request["model_id"])
            request = request_from_type(model_type, request)

            outputs = load_outputs(shm_info_list, shms)

            # postprocess() consumes the remaining request fields as keyword
            # arguments, so model_id is split off first.
            request_dict = request.dict()
            model_id = request_dict.pop("model_id")

            response = model_manager.postprocess(
                model_id,
                outputs,
                preproc_return_metadata,
                **request_dict,
                return_image_dims=True,
            )[0]

            write_response(redis_client, response, request.id)
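

# Rewrap each shared memory buffer as a numpy array, adding back a leading
# batch dimension of 1, since each segment holds the outputs for a single
# image.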
def load_outputs(
    shm_info_list: List[SharedMemoryMetadata], shms: List[shared_memory.SharedMemory]
) -> Tuple[np.ndarray, ...]:
    outputs = []
    for metadata, shm in zip(shm_info_list, shms):
        output = np.ndarray(
            [1] + metadata.array_shape, dtype=metadata.array_dtype, buffer=shm.buf
        )
        outputs.append(output)
    return tuple(outputs)
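

# Hand off to the inference worker: the serialized task goes into a
# per-model sorted set scored by request start time, and the "requests" hash
# tracks per-model queue depth.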
def queue_infer_task(
    redis: Redis,
    shm_metadata: SharedMemoryMetadata,
    request: InferenceRequest,
    preprocess_return_metadata: Dict,
):
    return_vals = {
        "shm_metadata": asdict(shm_metadata),
        "request": request.dict(),
        "preprocess_metadata": preprocess_return_metadata,
    }
    return_vals = json.dumps(return_vals)
    # Enqueue the task and bump the per-model counter atomically.
    pipe = redis.pipeline()
    pipe.zadd(f"infer:{request.model_id}", {return_vals: request.start})
    pipe.hincrby("requests", request.model_id, 1)
    pipe.execute()
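

# Publish the completed response on the shared "results" channel, tagged
# with the originating task id so subscribers can match it to the waiting
# request.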
def write_response(redis: Redis, response: InferenceResponse, request_id: str):
    response = response.dict(exclude_none=True, by_alias=True)
    redis.publish(
        "results",
        json.dumps(
            {"status": SUCCESS_STATE, "task_id": request_id, "payload": response}
        ),
    )