import logging
import time
from asyncio import Queue as AioQueue
from dataclasses import asdict
from multiprocessing import shared_memory
from queue import Empty, Queue
from threading import Thread
from typing import Dict, List, Tuple

import numpy as np
import orjson
from redis import ConnectionPool, Redis

from inference.core.entities.requests.inference import (
    InferenceRequest,
    request_from_type,
)
from inference.core.env import MAX_ACTIVE_MODELS, MAX_BATCH_SIZE, REDIS_HOST, REDIS_PORT
from inference.core.managers.base import ModelManager
from inference.core.managers.decorators.fixed_size_cache import WithFixedSizeCache
from inference.core.models.roboflow import RoboflowInferenceModel
from inference.core.registries.roboflow import RoboflowModelRegistry
from inference.enterprise.parallel.tasks import postprocess
from inference.enterprise.parallel.utils import (
    SharedMemoryMetadata,
    failure_handler,
    shm_manager,
)

# Configure the root logger at import time: records below WARNING are dropped,
# so this module's logger.info(...) timing messages are suppressed unless the
# level is lowered elsewhere.
# NOTE(review): this runs before the ROBOFLOW_MODEL_TYPES import below.
# logging.basicConfig does nothing once the root logger already has handlers,
# so keep this ordering if that import might configure logging itself.
logging.basicConfig(level=logging.WARNING)
# getLogger() with no name returns the root logger.
logger = logging.getLogger()

from inference.models.utils import ROBOFLOW_MODEL_TYPES

# Fallback batch size, used when a model's metadata reports a string batch size.
# It follows MAX_BATCH_SIZE, capped to 32 when MAX_BATCH_SIZE is unbounded (inf).
BATCH_SIZE = 32 if MAX_BATCH_SIZE == float("inf") else MAX_BATCH_SIZE
# Divisor applied to a batch's average request age (seconds) when scoring
# candidate batches in select_best_inference_batch.
AGE_TRADEOFF_SECONDS_FACTOR = 30


class InferServer:
    """Batch queued inference requests from Redis and run them through models.

    Work is split across three threads:

    * the thread calling ``infer_loop`` picks a model, claims a batch of its
      requests from Redis, loads the images from shared memory and puts the
      batch on ``batch_queue`` (``get_batch``);
    * ``infer_thread`` runs ``model_manager.predict`` on each batch and puts
      one response per request on ``response_queue`` (``infer``);
    * ``write_thread`` writes each response to shared memory and launches the
      postprocessing task (``write_responses``).

    ``running`` is cleared when ``infer_loop`` exits. The worker threads poll it,
    so the process can shut down instead of hanging on a blocked ``Queue.get``.
    """

    # Seconds a worker waits on an empty queue before re-checking ``running``.
    _QUEUE_POLL_SECONDS = 0.1

    def __init__(self, redis: Redis) -> None:
        self.redis = redis
        model_registry = RoboflowModelRegistry(ROBOFLOW_MODEL_TYPES)
        model_manager = ModelManager(model_registry)
        # Presumably caps the number of loaded models at MAX_ACTIVE_MODELS.
        self.model_manager = WithFixedSizeCache(
            model_manager, max_size=MAX_ACTIVE_MODELS
        )
        self.running = True
        self.response_queue = Queue()
        self.write_thread = Thread(target=self.write_responses)
        self.write_thread.start()
        # maxsize=1: at most one loaded batch waits for the model; get_batch
        # blocks on put() until the infer thread takes the waiting one.
        self.batch_queue = Queue(maxsize=1)
        self.infer_thread = Thread(target=self.infer)
        self.infer_thread.start()

    def write_responses(self):
        """Consume ``response_queue`` on ``write_thread`` until ``running`` is cleared.

        Each item is an ``(outputs, request, preprocess_metadata)`` tuple that is
        passed to ``write_infer_arrays_and_launch_postprocess``. Errors are
        logged, so one bad response does not stop the thread.
        """
        while self.running:
            try:
                response = self.response_queue.get(timeout=self._QUEUE_POLL_SECONDS)
            except Empty:
                continue
            try:
                write_infer_arrays_and_launch_postprocess(*response)
            except Exception as error:
                logger.warning(
                    "Encountered error while writing response:\n" + str(error),
                    exc_info=True,
                )

    def infer_loop(self):
        """Select, claim and load batches until ``running`` is cleared.

        Errors of type ``Exception`` are logged and the loop continues. When the
        loop exits for any reason (including ``KeyboardInterrupt``), ``running``
        is cleared so the worker threads stop and the process can exit.
        """
        try:
            while self.running:
                try:
                    model_names = get_requested_model_names(self.redis)
                    if not model_names:
                        # Nothing queued; back off briefly instead of spinning.
                        time.sleep(0.001)
                        continue
                    self.get_batch(model_names)
                except Exception as error:
                    logger.warning(
                        "Encountered error in infer loop:\n" + str(error),
                        exc_info=True,
                    )
        finally:
            self.running = False

    def infer(self):
        """Run prediction for each batch on ``batch_queue`` (runs on ``infer_thread``).

        The per-request outputs of each batch are queued on ``response_queue``.
        Failures are passed to ``failure_handler`` with the batch's request ids
        (as ``get_batch`` does for loading) and then logged. The thread must keep
        running: if it died, ``get_batch`` would block forever on
        ``batch_queue.put``.
        """
        while self.running:
            try:
                item = self.batch_queue.get(timeout=self._QUEUE_POLL_SECONDS)
            except Empty:
                continue
            model_id, images, batch, preproc_return_metadatas, request_ids = item
            try:
                with failure_handler(self.redis, *request_ids):
                    outputs = self.model_manager.predict(model_id, images)
                    # zip(*outputs) regroups the batched outputs into one tuple
                    # of arrays per request, aligned with ``batch``.
                    for output, b, metadata in zip(
                        zip(*outputs), batch, preproc_return_metadatas
                    ):
                        self.response_queue.put_nowait(
                            (output, b["request"], metadata)
                        )
            except Exception as error:
                logger.warning(
                    "Encountered error during inference:\n" + str(error),
                    exc_info=True,
                )

    def get_batch(self, model_names):
        """Claim the best batch for ``model_names``, load it and queue it for inference.

        Each request dict is rebuilt as a request object for the model's task
        type. Images are copied out of shared memory; ``shm_manager`` is called
        with ``unlink_on_success=True``. Failures are passed to
        ``failure_handler`` with the batch's request ids. This method blocks
        while ``batch_queue`` is full.
        """
        start = time.perf_counter()
        batch, model_id = get_batch(self.redis, model_names)
        if not batch:
            # Nothing could be claimed for the selected model this round.
            return
        request_ids = [b["request"]["id"] for b in batch]
        logger.info("Inferring: model<%s> batch_size<%d>", model_id, len(batch))
        with failure_handler(self.redis, *request_ids):
            self.model_manager.add_model(model_id, batch[0]["request"]["api_key"])
            model_type = self.model_manager.get_task_type(model_id)
            for b in batch:
                b["request"] = request_from_type(model_type, b["request"])
                b["shm_metadata"] = SharedMemoryMetadata(**b["shm_metadata"])

            metadata_processed = time.perf_counter()
            logger.info(
                "Took %.3f seconds to process metadata", metadata_processed - start
            )
            with shm_manager(
                *[b["shm_metadata"].shm_name for b in batch], unlink_on_success=True
            ) as shms:
                images, preproc_return_metadatas = load_batch(batch, shms)
                loaded = time.perf_counter()
                logger.info(
                    "Took %.3f seconds to load batch", loaded - metadata_processed
                )
            # load_batch returns copies, so the shared memory is released before
            # we (possibly) block here on a full batch_queue.
            self.batch_queue.put(
                (model_id, images, batch, preproc_return_metadatas, request_ids)
            )


def get_requested_model_names(redis: Redis) -> List[str]:
    """Return the model names whose count in the ``requests`` hash is positive.

    The hash maps a model name to its request count; entries with a count of
    zero or less are skipped. Names are returned in hash iteration order.
    """
    pending: List[str] = []
    for name, count in redis.hgetall("requests").items():
        if int(count) > 0:
            pending.append(name)
    return pending


def get_batch(redis: Redis, model_names: List[str]) -> Tuple[List[Dict], str]:
    """
    Run a heuristic to select the best batch to infer on, then claim it.

    For each model, up to its batch size of the lowest-scored members of the
    ``infer:<model>`` sorted set are read. ``select_best_inference_batch`` picks
    one model. That model's members are removed from the sorted set and
    deducted from its count in the ``requests`` hash.

    redis[Redis]: redis client
    model_names[List[str]]: list of models with nonzero number of requests
    returns:
        Tuple[List[Dict], str]
        List[Dict] represents a batch of request dicts (JSON-decoded members);
            it is empty when the selected model's sorted set had no members
        str is the model id
    """
    batch_sizes = []
    for model_name in model_names:
        size = RoboflowInferenceModel.model_metadata_from_memcache_endpoint(
            model_name
        )["batch_size"]
        # A string batch size in the model metadata falls back to BATCH_SIZE.
        batch_sizes.append(BATCH_SIZE if isinstance(size, str) else size)
    batches = [
        redis.zrange(f"infer:{m}", 0, size - 1, withscores=True)
        for m, size in zip(model_names, batch_sizes)
    ]
    model_index = select_best_inference_batch(batches, batch_sizes)
    entries = batches[model_index]
    selected_model = model_names[model_index]
    if not entries:
        # The request count can be positive while the sorted set is empty.
        # ZREM with no members is rejected by Redis, so claim nothing.
        return [], selected_model
    members = [member for member, _ in entries]
    redis.zrem(f"infer:{selected_model}", *members)
    redis.hincrby("requests", selected_model, -len(members))
    return [orjson.loads(member) for member in members], selected_model


def select_best_inference_batch(
    batches, batch_sizes, now=None, age_tradeoff_seconds_factor=None
):
    """Return the index of the batch that should be inferred next.

    Each candidate is scored as ``average_age / age_tradeoff_seconds_factor +
    fill``, where:

    * ``fill`` is ``len(batch) / batch_size``;
    * ``average_age`` is the mean of ``now - score`` over the batch's entries.

    Both fuller batches and older requests are therefore preferred. Ties go to
    the lowest index.

    batches: per-model lists of ``(member, score)`` pairs, as returned by
        ``zrange(..., withscores=True)``. Scores are assumed to be enqueue
        timestamps from ``time.time()`` (TODO confirm against the producer).
    batch_sizes: maximum batch size for each entry of ``batches``.
    now: reference time in seconds; defaults to ``time.time()``.
    age_tradeoff_seconds_factor: seconds of average age worth as much as a
        completely full batch; defaults to ``AGE_TRADEOFF_SECONDS_FACTOR``.
    returns: an index into ``batches``. An empty batch is never preferred over
        a non-empty one.
    raises: ValueError if ``batches`` is empty.
    """
    if now is None:
        now = time.time()
    if age_tradeoff_seconds_factor is None:
        age_tradeoff_seconds_factor = AGE_TRADEOFF_SECONDS_FACTOR

    fitnesses = []
    for batch, batch_size in zip(batches, batch_sizes):
        if not batch:
            # Nothing to infer. Averaging an empty batch would also give NaN,
            # which makes max() depend on list order.
            fitnesses.append(float("-inf"))
            continue
        # Age must be now - score. Using score - now is negative and would
        # favour the newest requests, starving old ones.
        average_age = sum(now - float(score) for _, score in batch) / len(batch)
        fill = len(batch) / batch_size
        fitnesses.append(average_age / age_tradeoff_seconds_factor + fill)
    return fitnesses.index(max(fitnesses))


def load_batch(
    batch: List[Dict[str, str]], shms: List[shared_memory.SharedMemory]
) -> Tuple[List[np.ndarray], List[Dict]]:
    """Copy each request's image out of its shared-memory buffer.

    ``batch[i]["shm_metadata"]`` supplies the array shape and dtype used to view
    ``shms[i].buf``. The view is copied, so the returned arrays do not alias the
    shared memory. Returns the images and each entry's ``preprocess_metadata``,
    both in batch order.
    """

    def _read(entry, segment):
        meta = entry["shm_metadata"]
        frame = np.ndarray(
            meta.array_shape, dtype=meta.array_dtype, buffer=segment.buf
        ).copy()
        return frame, entry["preprocess_metadata"]

    loaded = [_read(entry, segment) for entry, segment in zip(batch, shms)]
    images = [frame for frame, _ in loaded]
    metadatas = [extra for _, extra in loaded]
    return images, metadatas


def write_infer_arrays_and_launch_postprocess(
    arrs: Tuple[np.ndarray, ...],
    request: InferenceRequest,
    preproc_return_metadata: Dict,
):
    """Write inference results to shared memory and launch the postprocessing task.

    Each array in ``arrs`` is copied into its own newly created shared-memory
    segment. The segments' names, shapes and dtype names (as dicts) are passed
    to the ``postprocess`` task together with ``request.dict()`` and
    ``preproc_return_metadata``.

    arrs: the model outputs for a single request. Zero-sized arrays and 0-d
        values (e.g. NumPy scalars from a 1-D batched output) are supported.
    request: the request these outputs belong to.
    preproc_return_metadata: preprocessing metadata, forwarded unchanged.
    """
    shms = _create_shared_memories(arrs)
    with shm_manager(*shms):
        shm_metadatas = []
        for arr, shm in zip(arrs, shms):
            shared = np.ndarray(arr.shape, dtype=arr.dtype, buffer=shm.buf)
            # `[...]` assigns the whole array for any rank. For 0-d values,
            # `shared[:] = arr[:]` raises IndexError.
            shared[...] = arr
            shm_metadata = SharedMemoryMetadata(
                shm_name=shm.name, array_shape=arr.shape, array_dtype=arr.dtype.name
            )
            shm_metadatas.append(asdict(shm_metadata))

        postprocess.s(
            tuple(shm_metadatas), request.dict(), preproc_return_metadata
        ).delay()


def _create_shared_memories(arrs) -> List[shared_memory.SharedMemory]:
    """Create one shared-memory segment per array, releasing them all if any creation fails."""
    shms = []
    try:
        for arr in arrs:
            # SharedMemory rejects size=0, so zero-sized arrays get a 1-byte segment.
            shms.append(
                shared_memory.SharedMemory(create=True, size=max(arr.nbytes, 1))
            )
    except Exception:
        for shm in shms:
            shm.close()
            shm.unlink()
        raise
    return shms


if __name__ == "__main__":
    # decode_responses=True makes Redis return str. The model names read from
    # the "requests" hash are interpolated into "infer:<model>" key strings.
    pool = ConnectionPool(
        host=REDIS_HOST,
        port=REDIS_PORT,
        decode_responses=True,
    )
    redis = Redis(connection_pool=pool)
    server = InferServer(redis)
    server.infer_loop()