import json
from dataclasses import asdict
from multiprocessing import shared_memory
from typing import Dict, List, Tuple

import numpy as np
from celery import Celery
from redis import ConnectionPool, Redis

import inference.enterprise.parallel.celeryconfig
from inference.core.entities.requests.inference import (
    InferenceRequest,
    request_from_type,
)
from inference.core.entities.responses.inference import InferenceResponse
from inference.core.env import REDIS_HOST, REDIS_PORT, STUB_CACHE_SIZE
from inference.core.managers.decorators.fixed_size_cache import WithFixedSizeCache
from inference.core.managers.decorators.locked_load import (
    LockedLoadModelManagerDecorator,
)
from inference.core.managers.stub_loader import StubLoaderManager
from inference.core.registries.roboflow import RoboflowModelRegistry
from inference.enterprise.parallel.utils import (
    SUCCESS_STATE,
    SharedMemoryMetadata,
    failure_handler,
    shm_manager,
)
from inference.models.utils import ROBOFLOW_MODEL_TYPES

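# Process-wide setup shared by every task in this worker: a Redis connection
# pool, the Celery app, and a model manager that loads lightweight model
# stubs behind a load lock and caps the cache at STUB_CACHE_SIZE entries.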
pool = ConnectionPool(host=REDIS_HOST, port=REDIS_PORT, decode_responses=True)
app = Celery("tasks", broker=f"redis://{REDIS_HOST}:{REDIS_PORT}")
app.config_from_object(inference.enterprise.parallel.celeryconfig)
model_registry = RoboflowModelRegistry(ROBOFLOW_MODEL_TYPES)
model_manager = StubLoaderManager(model_registry)
model_manager = WithFixedSizeCache(
    LockedLoadModelManagerDecorator(model_manager), max_size=STUB_CACHE_SIZE
)


@app.task(queue="pre")
def preprocess(request: Dict):
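    """Ensure the model stub is loaded, preprocess the request's image, copy
    the result into shared memory, and queue an inference task pointing at
    it. Failures are reported under the request id by failure_handler."""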
    redis_client = Redis(connection_pool=pool)
    with failure_handler(redis_client, request["id"]):
        model_manager.add_model(request["model_id"], request["api_key"])
        model_type = model_manager.get_task_type(request["model_id"])
        request = request_from_type(model_type, request)
        image, preprocess_return_metadata = model_manager.preprocess(
            request.model_id, request
        )
        # multi-image requests are split into single-image requests upstream
        # and re-batched later, so only the sole image is handled here
        image = image[0]
        # drop the inline image payload: the pixels are written to shared
        # memory below, so they need not be serialized with the request again
        request.image.value = None
        shm = shared_memory.SharedMemory(create=True, size=image.nbytes)
        with shm_manager(shm):
            # view the shared buffer as an ndarray and copy the pixels in
            shared = np.ndarray(image.shape, dtype=image.dtype, buffer=shm.buf)
            shared[:] = image[:]
            shm_metadata = SharedMemoryMetadata(shm.name, image.shape, image.dtype.name)
            queue_infer_task(
                redis_client, shm_metadata, request, preprocess_return_metadata
            )


@app.task(queue="post")
def postprocess(
    shm_info_list: List[Dict], request: Dict, preproc_return_metadata: Dict
):
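    """Read the inference worker's raw outputs out of shared memory, run
    model postprocessing, and publish the final response. The shared memory
    segments are unlinked once postprocessing succeeds."""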
    redis_client = Redis(connection_pool=pool)
    # rehydrate the metadata dataclasses from the plain dicts Celery delivers
    shm_info_list: List[SharedMemoryMetadata] = [
        SharedMemoryMetadata(**metadata) for metadata in shm_info_list
    ]
    with failure_handler(redis_client, request["id"]):
        with shm_manager(
            *[shm_metadata.shm_name for shm_metadata in shm_info_list],
            unlink_on_success=True,
        ) as shms:
            model_manager.add_model(request["model_id"], request["api_key"])
            model_type = model_manager.get_task_type(request["model_id"])
            request = request_from_type(model_type, request)

            outputs = load_outputs(shm_info_list, shms)

            request_dict = request.dict()
            model_id = request_dict.pop("model_id")

            response = model_manager.postprocess(
                model_id,
                outputs,
                preproc_return_metadata,
                **request_dict,
                return_image_dims=True,
            )[0]  # single-image task: take the sole response

            write_response(redis_client, response, request.id)


def load_outputs(
    shm_info_list: List[SharedMemoryMetadata], shms: List[shared_memory.SharedMemory]
) -> Tuple[np.ndarray, ...]:
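    """Wrap each shared memory segment in a zero-copy numpy array view."""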
    outputs = []
    for args, shm in zip(shm_info_list, shms):
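        # prepend a batch dimension of 1; outputs are stored unbatched in
        # shared memory because requests were split into single-image tasks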
        output = np.ndarray(
            [1] + args.array_shape, dtype=args.array_dtype, buffer=shm.buf
        )
        outputs.append(output)
    return tuple(outputs)


def queue_infer_task(
    redis: Redis,
    shm_metadata: SharedMemoryMetadata,
    request: InferenceRequest,
    preprocess_return_metadata: Dict,
):
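    """Enqueue the serialized task for the inference worker: a per-model
    sorted set scored by request start time (oldest first), plus a hash
    counting queued requests per model id."""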
    return_vals = {
        "shm_metadata": asdict(shm_metadata),
        "request": request.dict(),
        "preprocess_metadata": preprocess_return_metadata,
    }
    return_vals = json.dumps(return_vals)
    pipe = redis.pipeline()
    pipe.zadd(f"infer:{request.model_id}", {return_vals: request.start})
    pipe.hincrby("requests", request.model_id, 1)
    pipe.execute()


def write_response(redis: Redis, response: InferenceResponse, request_id: str):
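    """Publish the response on the results pub/sub channel, tagged with
    SUCCESS_STATE and the originating task id."""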
    response = response.dict(exclude_none=True, by_alias=True)
    redis.publish(
        f"results",
        json.dumps(
            {"status": SUCCESS_STATE, "task_id": request_id, "payload": response}
        ),
    )