|
import json |
|
from contextlib import contextmanager |
|
from dataclasses import dataclass |
|
from multiprocessing import shared_memory |
|
from typing import List, Union |
|
|
|
from redis import Redis |
|
|
|
# Integer status codes for task results; FAILURE_STATE is the value that
# failure_handler publishes in the "status" field of a failure message.
SUCCESS_STATE: int = 1
FAILURE_STATE: int = -1
|
|
|
|
|
@contextmanager
def failure_handler(redis: Redis, *request_ids: str):
    """Report an exception raised inside the block to Redis, then re-raise it.

    If the wrapped block raises an ``Exception``, one JSON message of the form
    ``{"task_id": <id>, "status": FAILURE_STATE, "payload": "<Type>: <msg>"}``
    is published on the ``"results"`` channel for every id in
    ``request_ids``. After that, the original exception is re-raised
    unchanged. ``BaseException`` subclasses that are not ``Exception`` (for
    example ``KeyboardInterrupt``) propagate without publishing anything.

    Args:
        redis: Client used to publish the failure messages.
        *request_ids: Ids of the tasks whose failure should be reported.

    Raises:
        The exception raised by the wrapped block. A failure while publishing
        is logged and never replaces it.
    """
    try:
        yield
    except Exception as error:
        message = f"{type(error).__name__}: {error}"
        for request_id in request_ids:
            result = {
                "task_id": request_id,
                "status": FAILURE_STATE,
                "payload": message,
            }
            try:
                redis.publish("results", json.dumps(result))
            except Exception:
                # A failed publish (e.g. Redis unreachable) must neither hide
                # the original error nor stop the remaining ids from being
                # notified, so log it and carry on.
                import logging

                logging.getLogger(__name__).exception(
                    "Could not publish failure for task %s", request_id
                )
        # Bare raise re-raises `error`: the inner handlers have finished, so
        # the outer exception is the one currently being handled again.
        raise
|
|
|
|
|
@contextmanager
def shm_manager(
    *shms: Union[str, shared_memory.SharedMemory], unlink_on_success: bool = False
):
    """Attach to shared-memory blocks and close/free them when the block exits.

    Args:
        *shms: ``SharedMemory`` objects, or names of existing blocks to attach
            to by name.
        unlink_on_success: If true, also unlink (destroy) every block when the
            ``with`` body finishes without raising. Blocks are always closed.

    Yields:
        List of ``SharedMemory`` objects, in the order given in ``shms``.

    Raises:
        Exception: Raised with the list of underlying errors as its argument
            when one or more named blocks cannot be attached. Every block that
            was obtained is closed and unlinked first, exactly as when the
            ``with`` body itself raises.
    """
    loaded_shms: List[shared_memory.SharedMemory] = []
    try:
        # Collect every attach failure, not just the first one, so the error
        # reports all missing blocks and every block passed in is still freed
        # below.
        errors = []
        for item in shms:
            try:
                shm = (
                    shared_memory.SharedMemory(name=item)
                    if isinstance(item, str)
                    else item
                )
            except Exception as error:
                errors.append(error)
            else:
                loaded_shms.append(shm)
        if errors:
            raise Exception(errors) from errors[0]

        yield loaded_shms
    except BaseException:
        # Failure path (including KeyboardInterrupt): free every block, best
        # effort. close() raises BufferError while a view into shm.buf (e.g.
        # an array built over it) is still alive, and unlink() raises
        # FileNotFoundError if the block was already freed. Neither may skip
        # the remaining cleanup or replace the exception being propagated.
        for shm in loaded_shms:
            for release in (shm.close, shm.unlink):
                try:
                    release()
                except Exception:
                    pass
        raise
    else:
        for shm in loaded_shms:
            shm.close()
            if unlink_on_success:
                shm.unlink()
|
|
|
|
|
@dataclass()
class SharedMemoryMetadata:
    """Describes an array whose data lives in a shared-memory block.

    Bundles what a reader needs to rebuild the array: the name of the block,
    the array's dimensions and its element type.
    """

    # Name of the shared-memory block that holds the array data.
    shm_name: str
    # Dimensions of the stored array.
    array_shape: List[int]
    # Element data type of the array, stored as a string.
    array_dtype: str
|
|