import json
import time

from datasets import Dataset
from requests_futures.sessions import FuturesSession
from transformers import AutoTokenizer

from defaults import (ADDRESS_BETTERTRANSFORMER, ADDRESS_VANILLA, HEADERS,
                      MODEL_NAME)

# Template for reporting one synchronous inference request.
# Placeholders: {0} HTTP status code, {1} prediction, {2} server-side
# inference latency in ms, {3} peak GPU memory in MB, {4} client-measured
# end-to-end latency in ms. Padding ratio is fixed at 0.0 % for a single
# (unbatched) request.
RETURN_MESSAGE_SINGLE = """
Inference statistics:

* Response status: {0}
* Prediction: {1}
* Inference latency (preprocessing/forward/postprocessing): {2} ms
* Peak GPU memory usage: {3} MB
* End-to-end latency (communication + pre/forward/post): {4} ms
* Padding ratio: 0.0 %
"""

# Template for reporting aggregate statistics over many asynchronous
# requests. Placeholders: {0} throughput (samples/s), {1} mean inference
# latency in ms, {2} mean peak GPU memory in MB, {3} mean padding ratio in %,
# {4} mean sequence length in tokens, {5} effective mean batch size.
RETURN_MESSAGE_SPAM = """
Processing inputs sent asynchronously. Grab a coffee.

Inference statistics:

* Throughput: {0} samples/s
* Mean inference latency (preprocessing/forward/postprocessing): {1} ms
* Mean peak GPU memory: {2} MB
* Mean padding ratio: {3} %
* Mean sequence length: {4} tokens
* Effective mean batch size: {5}
"""

# NOTE(review): this tokenizer is not referenced anywhere in this module —
# presumably it is imported from here by other modules; confirm before removing.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)


def get_message_single(
    status, prediction, inf_latency, peak_gpu_memory, end_to_end_latency, **kwargs
):
    """Render the single-request statistics message.

    Extra keyword arguments are accepted and ignored so a full metrics dict
    can be unpacked into this call.
    """
    fields = (status, prediction, inf_latency, peak_gpu_memory, end_to_end_latency)
    return RETURN_MESSAGE_SINGLE.format(*fields)


def get_message_spam(
    throughput,
    mean_inference_latency,
    mean_peak_gpu_memory,
    mean_padding_ratio,
    mean_sequence_length,
    effective_batch_size,
    **kwargs,
):
    """Render the aggregate (spam-mode) statistics message.

    Extra keyword arguments are accepted and ignored so a full metrics dict
    can be unpacked into this call.
    """
    fields = (
        throughput,
        mean_inference_latency,
        mean_peak_gpu_memory,
        mean_padding_ratio,
        mean_sequence_length,
        effective_batch_size,
    )
    return RETURN_MESSAGE_SPAM.format(*fields)


SESSION = FuturesSession()


def send_single(input_model_vanilla: str, address: str):
    """POST one text input to the inference server and report its statistics.

    Returns the formatted statistics message on success, or the stringified
    exception if the request fails or times out.
    """
    assert address in [ADDRESS_VANILLA, ADDRESS_BETTERTRANSFORMER]

    payload = {"text": input_model_vanilla, "pre_tokenized": False}
    encoded = json.dumps(payload).encode("utf-8")

    # should not take more than 10 s, so timeout if that's the case
    sent_at = time.time()
    future = SESSION.post(address, headers=HEADERS, data=encoded, timeout=10)

    try:
        response = future.result()  # resolve ASAP
        received_at = time.time()
    except Exception as e:
        return f"{e}"

    # Server replies with a JSON list: [prediction, inference latency (ms),
    # peak GPU memory (MB)] — assumed from the indexing below; confirm server-side.
    body = json.loads(response.text)
    end_to_end_ms = round((received_at - sent_at) * 1e3, 2)

    return get_message_single(
        response.status_code,
        body[0],
        body[1],
        body[2],
        end_to_end_ms,
    )


def send_spam(inp: Dataset, address: str):
    """Send every sentence in ``inp`` to the server asynchronously and
    report aggregate throughput / latency / memory / padding statistics.

    Returns the formatted statistics message, a short notice for an empty
    dataset, or the stringified exception if any request fails or times out.
    """
    assert address in [ADDRESS_VANILLA, ADDRESS_BETTERTRANSFORMER]

    n_inputs = len(inp)
    # Guard: an empty dataset previously crashed with ZeroDivisionError in
    # the mean computations below (and produced a negative throughput).
    if n_inputs == 0:
        return "Empty dataset: nothing to send."

    promises = []

    start = time.time()
    for i in range(n_inputs):
        input_data = inp[i]["sentence"].encode("utf-8")

        # should not take more than 15 s, so timeout if that's the case
        promises.append(
            SESSION.post(address, headers=HEADERS, data=input_data, timeout=15)
        )

    # Resolve all futures first so `end` captures when the last response
    # arrived — this keeps the throughput measurement free of the JSON
    # parsing done in the second loop.
    end = 0
    for promise in promises:
        try:
            promise.result()  # resolve ASAP
        except Exception as e:
            return f"{e}"

        end = max(time.time(), end)

    mean_inference_latency = 0
    mean_peak_gpu_memory = 0
    n_pads = 0
    n_elems = 0
    sequence_length = 0
    effective_batch_size = 0

    # Futures cache their result, so calling .result() again is free.
    # Server replies with a JSON list: [prediction, inference latency,
    # peak GPU memory, #pad tokens, #tokens, sequence length, batch size]
    # — assumed from the indexing below; confirm server-side.
    for promise in promises:
        response_text = json.loads(promise.result().text)

        mean_inference_latency += response_text[1]
        mean_peak_gpu_memory += response_text[2]
        n_pads += response_text[3]
        n_elems += response_text[4]
        sequence_length += response_text[5]
        effective_batch_size += response_text[6]

    throughput = round(n_inputs / (end - start), 2)
    # Guard: if the server reported zero elements, avoid ZeroDivisionError.
    mean_padding_ratio = f"{n_pads / n_elems * 100:.2f}" if n_elems else "0.00"
    mean_sequence_length = sequence_length / n_inputs
    effective_batch_size = effective_batch_size / n_inputs

    mean_inference_latency = round(mean_inference_latency / n_inputs, 2)
    mean_peak_gpu_memory = round(mean_peak_gpu_memory / n_inputs, 2)

    return get_message_spam(
        throughput,
        mean_inference_latency,
        mean_peak_gpu_memory,
        mean_padding_ratio,
        mean_sequence_length,
        effective_batch_size,
    )