import os
import sys
import ctypes
import pathlib
from typing import Optional, List
import enum
from pathlib import Path

class DataType(enum.IntEnum):
    """
    Tensor data types accepted by ``minigpt4_quantize_model``.

    Members are passed to the C library as plain integers, so the numeric
    values (presumably mirroring the C-side enum — confirm before reordering)
    must not change.
    """

    F16 = 0
    F32 = 1
    I32 = 2
    L64 = 3
    Q4_0 = 4
    Q4_1 = 5
    Q5_0 = 6
    Q5_1 = 7
    Q8_0 = 8
    Q8_1 = 9
    Q2_K = 10
    Q3_K = 11
    Q4_K = 12
    Q5_K = 13
    Q6_K = 14
    Q8_K = 15

    def __str__(self):
        # Render as the member name (e.g. 'Q4_0') rather than the integer.
        return self.name

class Verbosity(enum.IntEnum):
    """Log verbosity levels understood by the C library (0 = silent .. 3 = debug)."""

    # Consecutive integers starting at 0: SILENT=0, ERR=1, INFO=2, DEBUG=3.
    SILENT, ERR, INFO, DEBUG = range(4)

class ImageFormat(enum.IntEnum):
    """Pixel storage formats stored in ``MiniGPT4Image.format``."""

    # Consecutive integers starting at 0: UNKNOWN=0, F32=1, U8=2.
    UNKNOWN, F32, U8 = range(3)

# Scalar aliases for the C types used in the minigpt4.cpp API declarations below.
I32 = ctypes.c_int32
U32 = ctypes.c_uint32
F32 = ctypes.c_float
SIZE_T = ctypes.c_size_t
VOID_PTR = ctypes.c_void_p

# Pointer aliases; ctypes.POINTER caches, so these are the same type objects
# as spelling out e.g. ctypes.POINTER(ctypes.c_float).
CHAR_PTR = ctypes.POINTER(ctypes.c_char)
FLOAT_PTR = ctypes.POINTER(F32)
INT_PTR = ctypes.POINTER(I32)
CHAR_PTR_PTR = ctypes.POINTER(CHAR_PTR)

# The model context is opaque to Python and travels as a plain void pointer.
MiniGPT4ContextP = VOID_PTR
class MiniGPT4Context:
    """Opaque handle around the context pointer returned by ``minigpt4_model_load``."""

    def __init__(self, ptr: Optional[int]):
        # Raw address as produced by a c_void_p restype: an int (None would mean NULL).
        self.ptr = ptr

class MiniGPT4Image(ctypes.Structure):
    """
    ctypes mirror of ``struct MiniGPT4Image``: a pixel buffer plus its geometry.

    Field order and types define the C memory layout and must not change.
    """

    _fields_ = [
        ('data', VOID_PTR),    # pointer to the pixel buffer
        ('width', I32),        # width in pixels
        ('height', I32),       # height in pixels
        ('channels', I32),     # number of channels per pixel
        ('format', I32),       # an ImageFormat value (e.g. ImageFormat.F32)
    ]

class MiniGPT4Embedding(ctypes.Structure):
    """
    ctypes mirror of ``struct MiniGPT4Embedding`` as filled in by
    ``minigpt4_encode_image``.

    Field order and types define the C memory layout and must not change.
    """

    _fields_ = [
        ('data', FLOAT_PTR),        # float buffer owned by the C side
        ('n_embeddings', SIZE_T),   # element/vector count — TODO confirm unit against C header
    ]

# Pointer types used as argtypes wherever the C API takes struct pointers.
MiniGPT4ImageP, MiniGPT4EmbeddingP = map(ctypes.POINTER, (MiniGPT4Image, MiniGPT4Embedding))

class MiniGPT4SharedLibrary:
    """
    Python wrapper around minigpt4.cpp shared library.

    ``__init__`` declares ``argtypes``/``restype`` for every C entry point the
    wrapper calls. A function without declarations would receive Python ints as
    32-bit C ``int`` values, which truncates 64-bit pointers such as the
    context handle.
    """

    def __init__(self, shared_library_path: str):
        """
        Loads the shared library from specified file.
        In case of any error, this method will throw an exception.

        Parameters
        ----------
        shared_library_path : str
            Path to minigpt4.cpp shared library. On Windows, it would look like 'minigpt4.dll'. On UNIX, 'libminigpt4.so'.
        """

        self.library = ctypes.cdll.LoadLibrary(shared_library_path)

        # Sampling arguments shared by minigpt4_end_chat and minigpt4_end_chat_image,
        # in C argument order (they follow ctx, token and n_threads).
        sampling_argtypes = [
            F32,  # float temp
            I32,  # int32_t top_k
            F32,  # float top_p
            F32,  # float tfs_z
            F32,  # float typical_p
            I32,  # int32_t repeat_last_n
            F32,  # float repeat_penalty
            F32,  # float alpha_presence
            F32,  # float alpha_frequency
            I32,  # int mirostat
            F32,  # float mirostat_tau
            F32,  # float mirostat_eta
            I32,  # int penalize_nl
        ]

        self._declare('minigpt4_model_load', MiniGPT4ContextP, [
            CHAR_PTR,  # const char *path
            CHAR_PTR,  # const char *llm_model
            I32,  # int verbosity
            I32,  # int seed
            I32,  # int n_ctx
            I32,  # int n_batch
            I32,  # int numa
        ])

        self._declare('minigpt4_image_load_from_file', I32, [
            MiniGPT4ContextP,  # struct MiniGPT4Context *ctx
            CHAR_PTR,  # const char *path
            MiniGPT4ImageP,  # struct MiniGPT4Image *image
            I32,  # int flags
        ])

        # Needed by minigpt4_preprocess_image() below; without it ctx.ptr would be
        # passed as a 32-bit int.
        self._declare('minigpt4_preprocess_image', I32, [
            MiniGPT4ContextP,  # struct MiniGPT4Context *ctx
            MiniGPT4ImageP,  # const struct MiniGPT4Image *image
            MiniGPT4ImageP,  # struct MiniGPT4Image *preprocessed_image
            I32,  # int flags
        ])

        self._declare('minigpt4_encode_image', I32, [
            MiniGPT4ContextP,  # struct MiniGPT4Context *ctx
            MiniGPT4ImageP,  # const struct MiniGPT4Image *image
            MiniGPT4EmbeddingP,  # struct MiniGPT4Embedding *embedding
            SIZE_T,  # size_t n_threads
        ])

        self._declare('minigpt4_begin_chat_image', I32, [
            MiniGPT4ContextP,  # struct MiniGPT4Context *ctx
            MiniGPT4EmbeddingP,  # struct MiniGPT4Embedding *embedding
            CHAR_PTR,  # const char *s
            SIZE_T,  # size_t n_threads
        ])

        self._declare('minigpt4_end_chat_image', I32, [
            MiniGPT4ContextP,  # struct MiniGPT4Context *ctx
            CHAR_PTR_PTR,  # const char **token
            SIZE_T,  # size_t n_threads
            *sampling_argtypes,
        ])

        self._declare('minigpt4_system_prompt', I32, [
            MiniGPT4ContextP,  # struct MiniGPT4Context *ctx
            SIZE_T,  # size_t n_threads
        ])

        self._declare('minigpt4_begin_chat', I32, [
            MiniGPT4ContextP,  # struct MiniGPT4Context *ctx
            CHAR_PTR,  # const char *s
            SIZE_T,  # size_t n_threads
        ])

        self._declare('minigpt4_end_chat', I32, [
            MiniGPT4ContextP,  # struct MiniGPT4Context *ctx
            CHAR_PTR_PTR,  # const char **token
            SIZE_T,  # size_t n_threads
            *sampling_argtypes,
        ])

        self._declare('minigpt4_reset_chat', I32, [
            MiniGPT4ContextP,  # struct MiniGPT4Context *ctx
        ])

        self._declare('minigpt4_contains_eos_token', I32, [
            CHAR_PTR,  # const char *s
        ])

        self._declare('minigpt4_is_eos', I32, [
            CHAR_PTR,  # const char *s
        ])

        self._declare('minigpt4_free', I32, [
            MiniGPT4ContextP,  # struct MiniGPT4Context *ctx
        ])

        self._declare('minigpt4_free_image', I32, [
            MiniGPT4ImageP,  # struct MiniGPT4Image *image
        ])

        self._declare('minigpt4_free_embedding', I32, [
            MiniGPT4EmbeddingP,  # struct MiniGPT4Embedding *embedding
        ])

        # c_char_p (not POINTER(c_char)) so ctypes returns bytes, or None for NULL.
        self._declare('minigpt4_error_code_to_string', ctypes.c_char_p, [
            I32,  # int error_code
        ])

        self._declare('minigpt4_quantize_model', I32, [
            CHAR_PTR,  # const char *in_path
            CHAR_PTR,  # const char *out_path
            I32,  # int data_type
        ])

        self._declare('minigpt4_set_verbosity', None, [
            I32,  # int verbosity
        ])

    def _declare(self, name: str, restype, argtypes: List) -> None:
        """Set the ctypes return type and argument types of C function ``name``."""
        func = getattr(self.library, name)
        func.argtypes = argtypes
        func.restype = restype

    @staticmethod
    def _decode_token(token) -> str:
        """Convert the ``const char *`` written by an end_chat call into ``str`` (NULL -> '')."""
        raw = ctypes.cast(token, ctypes.c_char_p).value
        if raw is None:
            return ''
        # NOTE(review): a single token may be an incomplete UTF-8 sequence
        # (e.g. byte-fallback tokens); 'replace' keeps generation from crashing.
        return raw.decode('utf-8', errors='replace')

    def _end_chat(self, func, ctx: MiniGPT4Context, n_threads: int, sampling: tuple) -> str:
        """Call one of the end_chat C functions and return the generated token."""
        token = CHAR_PTR()
        self.panic_if_error(func(ctx.ptr, ctypes.pointer(token), n_threads, *sampling))
        return self._decode_token(token)

    def panic_if_error(self, error_code: int) -> None:
        """
        Raises an exception if the error code is not 0.

        Parameters
        ----------
        error_code : int
            Error code to check.

        Raises
        ------
        RuntimeError
            If ``error_code`` is non-zero; the message is the library's description of the code.
        """

        if error_code != 0:
            raise RuntimeError(self.minigpt4_error_code_to_string(error_code))

    def minigpt4_model_load(self, model_path: str, llm_model_path: str, verbosity: int = 1, seed: int = 1337, n_ctx: int = 2048, n_batch: int = 512, numa: int = 0) -> MiniGPT4Context:
        """
        Loads a model from a file.

        Args:
            model_path (str): Path to model file.
            llm_model_path (str): Path to LLM model file.
            verbosity (int): Verbosity level: 0 = silent, 1 = error, 2 = info, 3 = debug. Defaults to 1.
            seed (int): Seed for llm model. Defaults to 1337.
            n_ctx (int): Size of context for llm model. Defaults to 2048.
            n_batch (int): Batch size passed to the llm model. Defaults to 512.
            numa (int): NUMA node to use (0 = NUMA disabled, 1 = NUMA enabled). Defaults to 0.

        Returns:
            MiniGPT4Context: Context.

        Raises:
            RuntimeError: If the library returns a NULL context.
        """

        ptr = self.library.minigpt4_model_load(
            model_path.encode('utf-8'),
            llm_model_path.encode('utf-8'),
            I32(verbosity),
            I32(seed),
            I32(n_ctx),
            I32(n_batch),
            I32(numa),
        )

        # Explicit check rather than `assert`, which is stripped under `python -O`.
        if not ptr:
            raise RuntimeError('minigpt4_model_load failed')

        return MiniGPT4Context(ptr)

    def minigpt4_image_load_from_file(self, ctx: MiniGPT4Context, path: str, flags: int) -> MiniGPT4Image:
        """
        Loads an image from a file.

        Args:
            ctx (MiniGPT4Context): Context.
            path (str): Path of the image file.
            flags (int): Flags, passed through to the C function unchanged.

        Returns:
            MiniGPT4Image: Image filled in by the library; release it with minigpt4_free_image.
        """

        image = MiniGPT4Image()
        self.panic_if_error(self.library.minigpt4_image_load_from_file(ctx.ptr, path.encode('utf-8'), ctypes.pointer(image), I32(flags)))
        return image

    def minigpt4_preprocess_image(self, ctx: MiniGPT4Context, image: MiniGPT4Image, flags: int = 0) -> MiniGPT4Image:
        """
        Preprocesses an image.

        Args:
            ctx (MiniGPT4Context): Context.
            image (MiniGPT4Image): Image.
            flags (int): Flags. Defaults to 0.

        Returns:
            MiniGPT4Image: Preprocessed image; release it with minigpt4_free_image.
        """

        preprocessed_image = MiniGPT4Image()
        self.panic_if_error(self.library.minigpt4_preprocess_image(ctx.ptr, ctypes.pointer(image), ctypes.pointer(preprocessed_image), I32(flags)))
        return preprocessed_image

    def minigpt4_encode_image(self, ctx: MiniGPT4Context, image: MiniGPT4Image, n_threads: int = 0) -> MiniGPT4Embedding:
        """
        Encodes an image into an embedding.

        Args:
            ctx (MiniGPT4Context): Context.
            image (MiniGPT4Image): Image.
            n_threads (int): Number of threads to use, if 0, uses all available. Defaults to 0.

        Returns:
            embedding (MiniGPT4Embedding): Output embedding; release it with minigpt4_free_embedding.
        """

        embedding = MiniGPT4Embedding()
        self.panic_if_error(self.library.minigpt4_encode_image(ctx.ptr, ctypes.pointer(image), ctypes.pointer(embedding), n_threads))
        return embedding

    def minigpt4_begin_chat_image(self, ctx: MiniGPT4Context, image_embedding: MiniGPT4Embedding, s: str, n_threads: int = 0):
        """
        Begins a chat with an image.

        Args:
            ctx (MiniGPT4Context): Context.
            image_embedding (MiniGPT4Embedding): Image embedding.
            s (str): Question to ask about the image.
            n_threads (int, optional): Number of threads to use, if 0, uses all available. Defaults to 0.

        Returns:
            None
        """

        self.panic_if_error(self.library.minigpt4_begin_chat_image(ctx.ptr, ctypes.pointer(image_embedding), s.encode('utf-8'), n_threads))

    def minigpt4_end_chat_image(self, ctx: MiniGPT4Context, n_threads: int = 0, temp: float = 0.8, top_k: int = 40, top_p: float = 0.9, tfs_z: float = 1.0, typical_p: float = 1.0, repeat_last_n: int = 64, repeat_penalty: float = 1.1, alpha_presence: float = 1.0, alpha_frequency: float = 1.0, mirostat: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 1.0, penalize_nl: int = 1) -> str:
        """
        Generates the next token of a chat begun with minigpt4_begin_chat_image.

        Args:
            ctx (MiniGPT4Context): Context.
            n_threads (int, optional): Number of threads to use, if 0, uses all available. Defaults to 0.
            temp (float, optional): Temperature. Defaults to 0.8.
            top_k (int, optional): Top K. Defaults to 40.
            top_p (float, optional): Top P. Defaults to 0.9.
            tfs_z (float, optional): Tfs Z. Defaults to 1.0.
            typical_p (float, optional): Typical P. Defaults to 1.0.
            repeat_last_n (int, optional): Repeat last N. Defaults to 64.
            repeat_penalty (float, optional): Repeat penalty. Defaults to 1.1.
            alpha_presence (float, optional): Alpha presence. Defaults to 1.0.
            alpha_frequency (float, optional): Alpha frequency. Defaults to 1.0.
            mirostat (int, optional): Mirostat. Defaults to 0.
            mirostat_tau (float, optional): Mirostat Tau. Defaults to 5.0.
            mirostat_eta (float, optional): Mirostat Eta. Defaults to 1.0.
            penalize_nl (int, optional): Penalize NL. Defaults to 1.

        Returns:
            str: Token generated ('' if the library returned NULL).
        """

        sampling = (temp, top_k, top_p, tfs_z, typical_p, repeat_last_n, repeat_penalty, alpha_presence, alpha_frequency, mirostat, mirostat_tau, mirostat_eta, penalize_nl)
        return self._end_chat(self.library.minigpt4_end_chat_image, ctx, n_threads, sampling)

    def minigpt4_system_prompt(self, ctx: MiniGPT4Context, n_threads: int = 0):
        """
        Generates a system prompt.

        Args:
            ctx (MiniGPT4Context): Context.
            n_threads (int, optional): Number of threads to use, if 0, uses all available. Defaults to 0.
        """

        self.panic_if_error(self.library.minigpt4_system_prompt(ctx.ptr, n_threads))

    def minigpt4_begin_chat(self, ctx: MiniGPT4Context, s: str, n_threads: int = 0):
        """
        Begins a chat continuing after minigpt4_begin_chat_image.

        Args:
            ctx (MiniGPT4Context): Context.
            s (str): Question to ask about the image.
            n_threads (int, optional): Number of threads to use, if 0, uses all available. Defaults to 0.

        Returns:
            None
        """
        self.panic_if_error(self.library.minigpt4_begin_chat(ctx.ptr, s.encode('utf-8'), n_threads))

    def minigpt4_end_chat(self, ctx: MiniGPT4Context, n_threads: int = 0, temp: float = 0.8, top_k: int = 40, top_p: float = 0.9, tfs_z: float = 1.0, typical_p: float = 1.0, repeat_last_n: int = 64, repeat_penalty: float = 1.1, alpha_presence: float = 1.0, alpha_frequency: float = 1.0, mirostat: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 1.0, penalize_nl: int = 1) -> str:
        """
        Generates the next token of a chat begun with minigpt4_begin_chat.

        Args:
            ctx (MiniGPT4Context): Context.
            n_threads (int, optional): Number of threads to use, if 0, uses all available. Defaults to 0.
            temp (float, optional): Temperature. Defaults to 0.8.
            top_k (int, optional): Top K. Defaults to 40.
            top_p (float, optional): Top P. Defaults to 0.9.
            tfs_z (float, optional): Tfs Z. Defaults to 1.0.
            typical_p (float, optional): Typical P. Defaults to 1.0.
            repeat_last_n (int, optional): Repeat last N. Defaults to 64.
            repeat_penalty (float, optional): Repeat penalty. Defaults to 1.1.
            alpha_presence (float, optional): Alpha presence. Defaults to 1.0.
            alpha_frequency (float, optional): Alpha frequency. Defaults to 1.0.
            mirostat (int, optional): Mirostat. Defaults to 0.
            mirostat_tau (float, optional): Mirostat Tau. Defaults to 5.0.
            mirostat_eta (float, optional): Mirostat Eta. Defaults to 1.0.
            penalize_nl (int, optional): Penalize NL. Defaults to 1.

        Returns:
            str: Token generated ('' if the library returned NULL).
        """

        sampling = (temp, top_k, top_p, tfs_z, typical_p, repeat_last_n, repeat_penalty, alpha_presence, alpha_frequency, mirostat, mirostat_tau, mirostat_eta, penalize_nl)
        return self._end_chat(self.library.minigpt4_end_chat, ctx, n_threads, sampling)

    def minigpt4_reset_chat(self, ctx: MiniGPT4Context):
        """
        Resets the chat.

        Args:
            ctx (MiniGPT4Context): Context.
        """
        self.panic_if_error(self.library.minigpt4_reset_chat(ctx.ptr))

    def minigpt4_contains_eos_token(self, s: str) -> bool:
        """
        Checks if a string contains an EOS token.

        Args:
            s (str): String to check.

        Returns:
            bool: True if the string contains an EOS token, False otherwise.
        """

        return bool(self.library.minigpt4_contains_eos_token(s.encode('utf-8')))

    def minigpt4_is_eos(self, s: str) -> bool:
        """
        Checks if a string is EOS.

        Args:
            s (str): String to check.

        Returns:
            bool: True if the library reports the string as EOS, False otherwise.
        """

        return bool(self.library.minigpt4_is_eos(s.encode('utf-8')))

    def minigpt4_free(self, ctx: MiniGPT4Context) -> None:
        """
        Frees a context.

        Args:
            ctx (MiniGPT4Context): Context.
        """

        self.panic_if_error(self.library.minigpt4_free(ctx.ptr))

    def minigpt4_free_image(self, image: MiniGPT4Image) -> None:
        """
        Frees an image.

        Args:
            image (MiniGPT4Image): Image.
        """

        self.panic_if_error(self.library.minigpt4_free_image(ctypes.pointer(image)))

    def minigpt4_free_embedding(self, embedding: MiniGPT4Embedding) -> None:
        """
        Frees an embedding.

        Args:
            embedding (MiniGPT4Embedding): Embedding.
        """

        self.panic_if_error(self.library.minigpt4_free_embedding(ctypes.pointer(embedding)))

    def minigpt4_error_code_to_string(self, error_code: int) -> str:
        """
        Converts an error code to a string.

        Args:
            error_code (int): Error code.

        Returns:
            str: Error string (a generic message if the library returns NULL).
        """

        message = self.library.minigpt4_error_code_to_string(error_code)
        if message is None:
            return f'Unknown minigpt4 error code {error_code}'
        return message.decode('utf-8', errors='replace')

    def minigpt4_quantize_model(self, in_path: str, out_path: str, data_type: DataType):
        """
        Quantizes a model file.

        Args:
            in_path (str): Path to input model file.
            out_path (str): Path to write output model file.
            data_type (DataType): Must be one of the DataType enum values.
        """

        self.panic_if_error(self.library.minigpt4_quantize_model(in_path.encode('utf-8'), out_path.encode('utf-8'), data_type))

    def minigpt4_set_verbosity(self, verbosity: Verbosity):
        """
        Sets verbosity.

        Args:
            verbosity (Verbosity): Verbosity.
        """

        self.library.minigpt4_set_verbosity(I32(verbosity))

def load_library() -> MiniGPT4SharedLibrary:
    """
    Attempts to find minigpt4.cpp shared library and load it.

    Candidate locations are tried in order; relative ones resolve against the
    current working directory. The first candidate that exists as a file is
    loaded. If none exists, the bare file name is passed to the platform
    loader, so its normal search path (e.g. LD_LIBRARY_PATH or PATH) is used.

    Raises:
        OSError: If the library cannot be loaded.
    """

    file_name: str

    if 'win32' in sys.platform or 'cygwin' in sys.platform:
        file_name = 'minigpt4.dll'
    elif 'darwin' in sys.platform:
        file_name = 'libminigpt4.dylib'
    else:
        file_name = 'libminigpt4.so'

    cwd = pathlib.Path(os.getcwd())
    repo_root_dir: pathlib.Path = pathlib.Path(os.path.abspath(__file__)).parent.parent

    paths = [
        # If we are in "minigpt4" directory
        f'../bin/Release/{file_name}',
        # If we are in repo root directory
        f'bin/Release/{file_name}',
        # If we compiled in build directory (multi-config generators)
        f'build/bin/Release/{file_name}',
        # If we compiled in build directory (single-config generators)
        f'build/{file_name}',
        f'../build/{file_name}',
        # Search relative to this file, independent of the working directory
        str(repo_root_dir / 'bin' / 'Release' / file_name),
        str(repo_root_dir / 'build' / 'bin' / 'Release' / file_name),
        str(repo_root_dir / 'build' / file_name),
        # Fallback
        str(repo_root_dir / file_name),
        str(cwd / file_name),
    ]

    for path in paths:
        if os.path.isfile(path):
            return MiniGPT4SharedLibrary(path)

    # Nothing found on disk: let the dynamic loader search its standard locations.
    return MiniGPT4SharedLibrary(file_name)

class MiniGPT4ChatBot:
    """
    High-level chat session over a single loaded MiniGPT-4 context.

    Typical use: ``upload_image`` to start an image conversation, iterate
    ``generate`` for each user message, and call ``free`` when finished to
    release the native resources.
    """

    def __init__(self, model_path: str, llm_model_path: str, verbosity: Verbosity = Verbosity.SILENT, n_threads: int = 0):
        """
        Creates a new MiniGPT4ChatBot instance.

        Args:
            model_path (str): Path to model file.
            llm_model_path (str): Path to language model model file.
            verbosity (Verbosity, optional): Verbosity. Defaults to Verbosity.SILENT.
            n_threads (int, optional): Number of threads to use. Defaults to 0.
        """

        # Import torchvision before loading the model so that a missing package
        # does not leave an already-loaded native context behind.
        from torchvision import transforms
        from torchvision.transforms.functional import InterpolationMode

        self.library = load_library()
        self.ctx = self.library.minigpt4_model_load(model_path, llm_model_path, verbosity)
        self.n_threads = n_threads
        self.image_size = 224

        # Per-channel mean / std applied by Normalize after ToTensor.
        mean = (0.48145466, 0.4578275, 0.40821073)
        std = (0.26862954, 0.26130258, 0.27577711)
        # Deterministic inference-time preprocessing: resize the whole image to a
        # square. RandomResizedCrop is a training augmentation that randomly crops,
        # which made answers depend on chance and could cut off the subject.
        self.transform = transforms.Compose(
            [
                transforms.Resize(
                    (self.image_size, self.image_size),
                    interpolation=InterpolationMode.BICUBIC,
                ),
                transforms.ToTensor(),
                transforms.Normalize(mean, std),
            ]
        )
        self.embedding: Optional[MiniGPT4Embedding] = None
        self.is_image_chat = False
        self.chat_history = []

    def free(self):
        """
        Releases the current image embedding and the native context.

        Safe to call more than once: released handles are reset to None.
        """
        if self.embedding is not None:
            self.library.minigpt4_free_embedding(self.embedding)
            self.embedding = None
        if self.ctx:
            self.library.minigpt4_free(self.ctx)
            self.ctx = None

    def generate(self, message: str, limit: int = 1024, temp: float = 0.8, top_k: int = 40, top_p: float = 0.9, tfs_z: float = 1.0, typical_p: float = 1.0, repeat_last_n: int = 64, repeat_penalty: float = 1.1, alpha_presence: float = 1.0, alpha_frequency: float = 1.0, mirostat: int = 0, mirostat_tau: float = 5.0, mirostat_eta: float = 1.0, penalize_nl: int = 1):
        """
        Generates a chat response, yielding it token by token.

        The first call after ``upload_image`` starts an image chat. Later calls
        continue the conversation as plain text.

        Args:
            message (str): Message.
            limit (int, optional): Maximum number of tokens to request. Defaults to 1024.
            temp (float, optional): Temperature. Defaults to 0.8.
            top_k (int, optional): Top K. Defaults to 40.
            top_p (float, optional): Top P. Defaults to 0.9.
            tfs_z (float, optional): TFS Z. Defaults to 1.0.
            typical_p (float, optional): Typical P. Defaults to 1.0.
            repeat_last_n (int, optional): Repeat last N. Defaults to 64.
            repeat_penalty (float, optional): Repeat penalty. Defaults to 1.1.
            alpha_presence (float, optional): Alpha presence. Defaults to 1.0.
            alpha_frequency (float, optional): Alpha frequency. Defaults to 1.0.
            mirostat (int, optional): Mirostat. Defaults to 0.
            mirostat_tau (float, optional): Mirostat tau. Defaults to 5.0.
            mirostat_eta (float, optional): Mirostat eta. Defaults to 1.0.
            penalize_nl (int, optional): Penalize NL. Defaults to 1.

        Yields:
            str: Generated tokens. Tokens containing an EOS marker are skipped, and
            generation stops once the accumulated text is reported as EOS.
        """
        sampling = (temp, top_k, top_p, tfs_z, typical_p, repeat_last_n, repeat_penalty, alpha_presence, alpha_frequency, mirostat, mirostat_tau, mirostat_eta, penalize_nl)

        if self.is_image_chat:
            # Only the first message after an upload carries the image embedding.
            self.is_image_chat = False
            self.library.minigpt4_begin_chat_image(self.ctx, self.embedding, message, self.n_threads)
            end_chat = self.library.minigpt4_end_chat_image
        else:
            self.library.minigpt4_begin_chat(self.ctx, message, self.n_threads)
            end_chat = self.library.minigpt4_end_chat

        chat = ''
        for _ in range(limit):
            token = end_chat(self.ctx, self.n_threads, *sampling)
            chat += token
            if self.library.minigpt4_contains_eos_token(token):
                continue
            if self.library.minigpt4_is_eos(chat):
                break
            yield token

    def reset_chat(self):
        """
        Resets the chat: drops any image embedding, clears the native chat
        state and re-issues the system prompt.
        """

        self.is_image_chat = False
        if self.embedding is not None:
            self.library.minigpt4_free_embedding(self.embedding)
            self.embedding = None

        self.library.minigpt4_reset_chat(self.ctx)
        self.library.minigpt4_system_prompt(self.ctx, self.n_threads)

    def upload_image(self, image):
        """
        Uploads an image and starts a new image conversation.

        Args:
            image: Image accepted by ``self.transform`` (e.g. a PIL image).
        """

        self.reset_chat()

        tensor = self.transform(image).unsqueeze(0).contiguous()
        # Keep the float32 array bound to a local for as long as the C side reads
        # from the pointer handed to it below.
        pixels = tensor.numpy()
        minigpt4_image = MiniGPT4Image(
            pixels.ctypes.data_as(ctypes.c_void_p),
            self.image_size,
            self.image_size,
            3,
            ImageFormat.F32,
        )
        self.embedding = self.library.minigpt4_encode_image(self.ctx, minigpt4_image, self.n_threads)

        self.is_image_chat = True


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Test loading minigpt4')
    parser.add_argument('model_path', help='Path to model file')
    parser.add_argument('llm_model_path', help='Path to llm model file')
    parser.add_argument('-i', '--image_path', help='Image to test', default='images/llama.png')
    parser.add_argument('-p', '--prompts', help='Text to test', default='what is the text in the picture?,what is the color of it?')
    args = parser.parse_args()

    model_path = args.model_path
    llm_model_path = args.llm_model_path
    image_path = args.image_path
    prompts = args.prompts

    # sys.exit rather than the site-provided exit(), which is not guaranteed to exist.
    if not Path(model_path).exists():
        print(f'Model does not exist: {model_path}')
        sys.exit(1)

    if not Path(llm_model_path).exists():
        print(f'LLM Model does not exist: {llm_model_path}')
        sys.exit(1)

    prompts = prompts.split(',')

    print('Loading minigpt4 shared library...')
    library = load_library()
    print(f'Loaded library {library}')
    ctx = library.minigpt4_model_load(model_path, llm_model_path, Verbosity.DEBUG)
    n_threads = 0

    def stream_answer(end_chat) -> None:
        """Print tokens from ``end_chat`` until the accumulated text is EOS, then a newline."""
        chat = ''
        while True:
            token = end_chat(ctx, n_threads)
            chat += token
            if library.minigpt4_contains_eos_token(token):
                continue
            if library.minigpt4_is_eos(chat):
                break
            print(token, end='', flush=True)
        print()

    image = preprocessed_image = embedding = None
    try:
        image = library.minigpt4_image_load_from_file(ctx, image_path, 0)
        preprocessed_image = library.minigpt4_preprocess_image(ctx, image, 0)
        embedding = library.minigpt4_encode_image(ctx, preprocessed_image, n_threads)

        library.minigpt4_system_prompt(ctx, n_threads)
        library.minigpt4_begin_chat_image(ctx, embedding, prompts[0], n_threads)
        stream_answer(library.minigpt4_end_chat_image)

        for prompt in prompts[1:]:
            library.minigpt4_begin_chat(ctx, prompt, n_threads)
            stream_answer(library.minigpt4_end_chat)
    finally:
        # Release every native resource that was successfully created.
        if embedding is not None:
            library.minigpt4_free_embedding(embedding)
        if preprocessed_image is not None:
            library.minigpt4_free_image(preprocessed_image)
        if image is not None:
            library.minigpt4_free_image(image)
        library.minigpt4_free(ctx)