import json
import random
import torch
from typing import Any
from typing import Optional
from fastapi import FastAPI
from pydantic import BaseModel
from vllm import LLM, SamplingParams, RequestOutput


# Don't forget to set HF_TOKEN in the environment when running

cuda_num_device: int = 0
if torch.cuda.is_available():
    random_seed = 42
    random.seed(random_seed)

    device = torch.device('cuda')
    torch.cuda.manual_seed(random_seed)

    print(f"Using device: {device}")
    print(f"CUDA available and enabled. {torch.cuda}")
    print(f"CUDA is available: {torch.cuda.is_available()}")
    print(f"CUDA device count: {torch.cuda.device_count()}")
    print(f"CUDA current device: {torch.cuda.current_device()}")

    for i in range(torch.cuda.device_count()):
        print('=================================================================')
        print(torch.cuda.get_device_name(i))
        print('Memory Usage:')
        print('Allocated:', round(torch.cuda.memory_allocated(i) / 1024 ** 3, 1), 'GB')
        print('Cached:   ', round(torch.cuda.memory_reserved(i) / 1024 ** 3, 1), 'GB')


app = FastAPI()

# Initialize the LLM engine
# Replace 'your-model-path' with the actual path or name of your model
# example:
# https://huggingface.co/spaces/damienbenveniste/deploy_vLLM/blob/b210a934d4ff7b68254d42fa28736d74649e610d/app.py#L17-L20

engine_llama_3_2: LLM = LLM(
    model='meta-llama/Llama-3.2-3B-Instruct',
    revision="0cb88a4f764b7a12671c53f0838cd831a0843b95",
    # https://github.com/vllm-project/vllm/blob/v0.6.4/vllm/config.py#L1062-L1065
    max_num_batched_tokens=32768,  # Reduced for T4; must match max_model_len
    max_num_seqs=16,               # Reduced for T4
    gpu_memory_utilization=0.85,   # Slightly increased, adjust if needed
    tensor_parallel_size=1,

    # Llama-3.2-3B-Instruct supports a context length of up to 131072 tokens, but we reduce it to 32k.
    # 32k tokens is roughly 24k words (about 3/4 of the token count); at an average of
    # 500 (0.5k) words per page, that is roughly 24k / 0.5k = 48 pages.
    # Using the full context length would be slower, and the T4 does not have enough memory for it.
    # https://github.com/vllm-project/vllm/blob/v0.6.4/vllm/config.py#L85-L86
    # https://github.com/vllm-project/vllm/blob/v0.6.4/vllm/config.py#L98-L102
    # [rank0]:     raise ValueError(
    # [rank0]: ValueError: The model's max seq len (131072)
    #   is larger than the maximum number of tokens that can be stored in KV cache (57056).
    #   Try increasing `gpu_memory_utilization` or decreasing `max_model_len` when initializing the engine.
    max_model_len=32768,           # Reduced for T4
    enforce_eager=True,            # Disable CUDA graph

    # File "/home/user/.local/lib/python3.12/site-packages/vllm/worker/worker.py",
    # line 479, in _check_if_gpu_supports_dtype
    # Bfloat16 is only supported on GPUs with compute capability of at least 8.0.
    # Your Tesla T4 GPU has compute capability 7.5.
    # You can use float16 instead by explicitly setting the`dtype` flag in CLI, for example: --dtype=half.
    dtype='half',                  # Use 'half' for T4
)
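
# Sanity check (assumption, not part of the original app): log the resolved engine
# config at startup so we can confirm the T4-friendly limits above took effect.
# This reuses the same `llm_engine.model_config` attribute chain as the "/" endpoint below.
print(f"Engine max_model_len: {engine_llama_3_2.llm_engine.model_config.max_model_len}")
print(f"Engine dtype: {engine_llama_3_2.llm_engine.model_config.dtype}")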

# # ValueError: max_num_batched_tokens (512) is smaller than max_model_len (32768).
# # This effectively limits the maximum sequence length to max_num_batched_tokens and makes vLLM reject longer sequences.
# # Please increase max_num_batched_tokens or decrease max_model_len.
# engine_sailor_chat: LLM = LLM(
#     model='sail/Sailor-4B-Chat',
#     revision="89a866a7041e6ec023dd462adeca8e28dd53c83e",
#     max_num_batched_tokens=32768,    # Reduced for T4
#     max_num_seqs=16,                 # Reduced for T4
#     gpu_memory_utilization=0.85,     # Slightly increased, adjust if needed
#     tensor_parallel_size=1,
#     max_model_len=32768,
#     enforce_eager=True,              # Disable CUDA graph
#     dtype='half',                    # Use 'half' for T4
# )


@app.get("/")
def greet_json():
    cuda_info: dict[str, Any] = {}
    if torch.cuda.is_available():
        cuda_current_device: int = torch.cuda.current_device()
        cuda_info = {
            "device_count": torch.cuda.device_count(),
            "cuda_device": torch.cuda.get_device_name(cuda_current_device),
            "cuda_capability": torch.cuda.get_device_capability(cuda_current_device),
            "allocated":  f"{round(torch.cuda.memory_allocated(cuda_current_device) / 1024 ** 3, 1)} GB",
            "cached": f"{round(torch.cuda.memory_reserved(cuda_current_device) / 1024 ** 3, 1)} GB",
        }

    return {
        "message": f"CUDA availability is {torch.cuda.is_available()}",
        "cuda_info": cuda_info,
        "model": [
            {
                "name": "meta-llama/Llama-3.2-3B-Instruct",
                "revision": "0cb88a4f764b7a12671c53f0838cd831a0843b95",
                "max_model_len": engine_llama_3_2.llm_engine.model_config.max_model_len,
            },
        ]
    }


class GenerationRequest(BaseModel):
    prompt: str
    max_tokens: int = 100
    temperature: float = 0.7
    logit_bias: Optional[dict[int, float]] = None


class GenerationResponse(BaseModel):
    text: Optional[str] = None
    error: Optional[str] = None


@app.post("/generate-llama3-2")
def generate_text(request: GenerationRequest) -> dict[str, Any]:
    try:
        sampling_params: SamplingParams = SamplingParams(
            temperature=request.temperature,
            max_tokens=request.max_tokens,
            logit_bias=request.logit_bias,
        )

        # Generate text
        response: list[RequestOutput] = engine_llama_3_2.generate(
            prompts=request.prompt,
            sampling_params=sampling_params
        )

        # Collect the generated text from each prompt's outputs
        outputs: list[dict[str, Any]] = []
        for item in response:
            for out in item.outputs:
                outputs.append({
                    "text": out.text,
                })

        return {
            "output": outputs,
        }

    except Exception as e:
        return {
            "error": str(e)
        }
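

# Sketch (assumption, not part of the original endpoints): Llama-3.2-3B-Instruct
# is a chat-tuned model, so wrapping the raw prompt with the tokenizer's chat
# template usually yields better completions than passing plain text straight to
# generate(). `LLM.get_tokenizer()` and `apply_chat_template()` are standard
# vLLM / Hugging Face APIs; the route name below is illustrative only.
#
# @app.post("/generate-llama3-2-chat")
# def generate_text_chat(request: GenerationRequest) -> dict[str, Any]:
#     tokenizer = engine_llama_3_2.get_tokenizer()
#     chat_prompt: str = tokenizer.apply_chat_template(
#         [{"role": "user", "content": request.prompt}],
#         tokenize=False,
#         add_generation_prompt=True,
#     )
#     sampling_params = SamplingParams(
#         temperature=request.temperature,
#         max_tokens=request.max_tokens,
#     )
#     response = engine_llama_3_2.generate(
#         prompts=chat_prompt,
#         sampling_params=sampling_params,
#     )
#     return {
#         "output": [
#             {"text": out.text} for item in response for out in item.outputs
#         ],
#     }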


# @app.post("/generate-sailor-chat")
# def generate_text(request: GenerationRequest) -> list[RequestOutput] | dict[str, str]:
#     try:
#         sampling_params: SamplingParams = SamplingParams(
#             temperature=request.temperature,
#             max_tokens=request.max_tokens,
#             logit_bias=request.logit_bias,
#         )
#
#         # Generate text
#         return engine_sailor_chat.generate(
#             prompts=request.prompt,
#             sampling_params=sampling_params
#         )
#
#     except Exception as e:
#         return {
#             "error": str(e)
#         }
#
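

# Sketch (assumption): entry point for running the API locally with uvicorn.
# On a Hugging Face Space the platform typically launches the app itself, so this
# guard is only needed for local runs; the host/port values are illustrative
# (7860 is the port Spaces conventionally expose).
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)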