import codecs
import ctypes
import os
import pathlib
import sys
from ctypes import (
    c_bool,
    c_char_p,
    c_int,
    c_int8,
    c_int32,
    c_uint8,
    c_uint32,
    c_size_t,
    c_float,
    c_double,
    c_void_p,
    POINTER,
    _Pointer,  # type: ignore
    Structure,
    Array,
)
from typing import Callable, Generator, Iterator, List, Optional, Union

# Helper that locates and loads the native shared library
def _load_shared_library(lib_base_name: str):
    """Locate the native shared library and load it with ctypes.

    Resolution order:

    1. If the ``LLAMA2_CU_LIB`` environment variable holds a non-empty value,
       it is taken as the path of the library file and is the only candidate
       tried, on any platform.
    2. Otherwise ``lib<lib_base_name>.so`` in the directory of this module is
       used; this default is only available on Linux.

    Args:
        lib_base_name: Library name without the ``lib`` prefix and without
            the file extension, e.g. ``"llama2"``.

    Returns:
        The loaded ``ctypes.CDLL`` handle.

    Raises:
        RuntimeError: If no override is given on a non-Linux platform, or if
            a candidate file exists but cannot be loaded. In the latter case
            the underlying ``OSError`` is chained as ``__cause__``.
        FileNotFoundError: If no candidate file exists.
    """
    override = os.environ.get("LLAMA2_CU_LIB")
    if override:
        # An explicit path wins over the platform check: the check below only
        # governs how the *default* file name is derived.
        # Rebinding keeps the "not found" message naming what was searched.
        lib_base_name = override
        candidates: List[pathlib.Path] = [pathlib.Path(override).resolve()]
    elif sys.platform.startswith("linux"):
        module_dir = pathlib.Path(os.path.abspath(os.path.dirname(__file__)))
        candidates = [module_dir / f"lib{lib_base_name}.so"]
    else:
        raise RuntimeError("Unsupported platform")

    for candidate in candidates:
        if not candidate.exists():
            continue
        try:
            return ctypes.CDLL(str(candidate))
        except OSError as e:
            # Chain the loader error so the dlopen diagnostics stay visible.
            raise RuntimeError(
                f"Failed to load shared library '{candidate}': {e}"
            ) from e

    raise FileNotFoundError(
        f"Shared library with base name '{lib_base_name}' not found"
    )

# Base name of the shared library. The loader looks for "lib<base name>.so"
# next to this module unless the LLAMA2_CU_LIB environment variable names a
# library file explicitly.
_lib_base_name = "llama2"

# Load the library once, at import time. A loading failure propagates as an
# exception from this import; the wrappers below all call into this handle.
_lib = _load_shared_library(_lib_base_name)


# ctypes prototype for the native entry point: (char*, char*) -> void*
_lib.llama2_init.restype = c_void_p
_lib.llama2_init.argtypes = [c_char_p, c_char_p]

def llama2_init(model_path: str, tokenizer_path: str) -> c_void_p:
    """Call the native ``llama2_init`` and return the context handle it yields.

    Both paths are encoded as UTF-8 before being handed to the library.
    Because the result type is declared as ``c_void_p``, ctypes delivers the
    handle as a plain integer address, or ``None`` for a NULL pointer.

    Args:
        model_path: Path of the model file.
        tokenizer_path: Path of the tokenizer file.

    Returns:
        The opaque context handle to pass to the other ``llama2_*`` wrappers.
    """
    model_bytes = model_path.encode('utf-8')
    tokenizer_bytes = tokenizer_path.encode('utf-8')
    return _lib.llama2_init(model_bytes, tokenizer_bytes)

# ctypes prototype for the native entry point: (void*) -> void
_lib.llama2_free.restype = None
_lib.llama2_free.argtypes = [c_void_p]

def llama2_free(ctx: c_void_p) -> None:
    """Pass ``ctx`` to the native ``llama2_free``.

    Args:
        ctx: Context handle previously obtained from ``llama2_init``.
    """
    _lib.llama2_free(ctx)

# ctypes prototype: (void*, char*, int, float, float, int) -> int
_lib.llama2_generate.restype = c_int
_lib.llama2_generate.argtypes = [c_void_p, c_char_p, c_int, c_float, c_float, c_int]

def llama2_generate(
    ctx: c_void_p,
    prompt: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    seed: int,
) -> int:
    """Call the native ``llama2_generate`` for ``prompt``.

    The prompt is encoded as UTF-8; every other argument is forwarded
    unchanged and converted by ctypes according to the declared prototype.

    Args:
        ctx: Context handle obtained from ``llama2_init``.
        prompt: Prompt text.
        max_tokens: Forwarded as a C ``int``.
        temperature: Forwarded as a C ``float``.
        top_p: Forwarded as a C ``float``.
        seed: Forwarded as a C ``int``.

    Returns:
        The integer status returned by the library. ``Llama2.__call__``
        treats any non-zero value as a failure.
    """
    prompt_bytes = prompt.encode('utf-8')
    status = _lib.llama2_generate(
        ctx, prompt_bytes, max_tokens, temperature, top_p, seed
    )
    return status

def llama2_get_last(ctx: c_void_p) -> Optional[bytes]:
    """Call the native ``llama2_get_last`` and return its result as bytes.

    The result type is declared as ``c_char_p``, so ctypes copies the
    NUL-terminated string the library points at into a ``bytes`` object, and
    returns ``None`` when the library returns a NULL pointer.

    Args:
        ctx: Context handle obtained from ``llama2_init``.

    Returns:
        The raw bytes of the latest piece of output, or ``None`` for a NULL
        result. ``Llama2.__call__`` treats ``None`` as the end of the output.
        The bytes are not guaranteed to be a complete UTF-8 sequence.
    """
    return _lib.llama2_get_last(ctx)

_lib.llama2_get_last.argtypes = [c_void_p]
_lib.llama2_get_last.restype = c_char_p

def llama2_tokenize(ctx: c_void_p, text: str, add_bos: bool, add_eos: bool) -> List[int]:
    """Tokenize ``text`` with the native ``llama2_tokenize``.

    Args:
        ctx: Context handle obtained from ``llama2_init``.
        text: Text to tokenize; it is passed to the library as UTF-8.
        add_bos: Forwarded to the library as a C ``int8`` flag.
        add_eos: Forwarded to the library as a C ``int8`` flag.

    Returns:
        The token ids the library wrote, as a list of ``int``.
    """
    encoded = text.encode('utf-8')
    # The library sees UTF-8 bytes, not Python characters, so the output
    # buffer must be sized from the encoded length. Sizing it from
    # ``len(text)`` under-allocates for non-ASCII input whenever the
    # tokenizer emits more tokens than there are code points (for example
    # one token per byte), and the library would then write past the buffer.
    # The extra 3 slots are kept from the original sizing.
    # NOTE(review): presumably room for BOS, EOS and a terminator - confirm
    # against the C side.
    capacity = len(encoded) + 3
    tokens = (c_int * capacity)()
    n_tokens = c_int(0)
    _lib.llama2_tokenize(
        ctx, encoded, add_bos, add_eos, tokens, ctypes.byref(n_tokens)
    )
    # Slicing a ctypes array yields a plain Python list.
    return tokens[:n_tokens.value]

_lib.llama2_tokenize.argtypes = [c_void_p, c_char_p, c_int8, c_int8, POINTER(c_int), POINTER(c_int)]
_lib.llama2_tokenize.restype = None

class Llama2:
    def __init__(
        self, 
        model_path: str,
        tokenizer_path: str='tokenizer.bin',
        n_ctx: int = 0,
        n_batch: int = 0) -> None:
        self.n_ctx = n_ctx
        self.n_batch = n_batch
        self.llama2_ctx = llama2_init(model_path, tokenizer_path)

    def tokenize(
        self, text: str, add_bos: bool = True, add_eos: bool = False
    ) -> List[int]:
        return llama2_tokenize(self.llama2_ctx, text, add_bos, add_eos)
    
    def __call__(
        self,
        prompt: str,
        max_tokens: int = 128,
        temperature: float = 0.8,
        top_p: float = 0.95,
        min_p: float = 0.05,
        typical_p: float = 1.0,
        logprobs: Optional[int] = None,
        frequency_penalty: float = 0.0,
        presence_penalty: float = 0.0,
        repeat_penalty: float = 1.1,
        top_k: int = 40,
        stream: bool = False,
        seed: Optional[int] = None,
    ) -> Iterator[str]:
        if seed is None:
            seed = 42
        ret = llama2_generate(self.llama2_ctx, prompt, max_tokens, temperature, top_p, seed)
        if ret != 0:
            raise RuntimeError(f"Failed to launch generation for prompt '{prompt}'")
        bytes_buffer = b''  # store generated bytes until decoded (in case of multi-byte characters)
        while True:
            result = llama2_get_last(self.llama2_ctx)
            if result is None:
                break
            bytes_buffer += result
            try:
                string = bytes_buffer.decode('utf-8')
            except UnicodeDecodeError:
                pass
            else:
                bytes_buffer = b''
                yield string