import struct
from enum import IntEnum
from typing import Any, Iterator, NamedTuple

from fsspec.spec import AbstractBufferedFile


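# Token type values stored in tokenizer.ggml.token_type (matching gguf-py's TokenType).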
class TokenType(IntEnum):
    NORMAL       = 1
    UNKNOWN      = 2
    CONTROL      = 3
    USER_DEFINED = 4
    UNUSED       = 5
    BYTE         = 6


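# File-type values stored in general.file_type (matching llama.cpp's llama_ftype).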
class LlamaFileType(IntEnum):
    ALL_F32              = 0
    MOSTLY_F16           = 1
    MOSTLY_Q4_0          = 2
    MOSTLY_Q4_1          = 3
    MOSTLY_Q4_1_SOME_F16 = 4
    MOSTLY_Q4_2          = 5
    MOSTLY_Q4_3          = 6
    MOSTLY_Q8_0          = 7
    MOSTLY_Q5_0          = 8
    MOSTLY_Q5_1          = 9
    MOSTLY_Q2_K          = 10
    MOSTLY_Q3_K_S        = 11
    MOSTLY_Q3_K_M        = 12
    MOSTLY_Q3_K_L        = 13
    MOSTLY_Q4_K_S        = 14
    MOSTLY_Q4_K_M        = 15
    MOSTLY_Q5_K_S        = 16
    MOSTLY_Q5_K_M        = 17
    MOSTLY_Q6_K          = 18
    MOSTLY_IQ2_XXS       = 19
    MOSTLY_IQ2_XS        = 20
    MOSTLY_Q2_K_S        = 21
    MOSTLY_IQ3_XS        = 22
    MOSTLY_IQ3_XXS       = 23
    MOSTLY_IQ1_S         = 24
    MOSTLY_IQ4_NL        = 25
    MOSTLY_IQ3_S         = 26
    MOSTLY_IQ3_M         = 27
    MOSTLY_IQ2_S         = 28
    MOSTLY_IQ2_M         = 29
    MOSTLY_IQ4_XS        = 30
    MOSTLY_IQ1_M         = 31
    MOSTLY_BF16          = 32
    MOSTLY_Q4_0_4_4      = 33
    MOSTLY_Q4_0_4_8      = 34
    MOSTLY_Q4_0_8_8      = 35
    MOSTLY_TQ1_0         = 36
    MOSTLY_TQ2_0         = 37


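# Metadata value types as defined by the GGUF specification.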
class GGUFValueType(IntEnum):
    UINT8   = 0
    INT8    = 1
    UINT16  = 2
    INT16   = 3
    UINT32  = 4
    INT32   = 5
    FLOAT32 = 6
    BOOL    = 7
    STRING  = 8
    ARRAY   = 9
    UINT64  = 10
    INT64   = 11
    FLOAT64 = 12


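# Well-known metadata keys mapped to their value type and a default value; for
# keys whose default is a list, the type refers to the array's element type.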
standard_metadata = {
    "general.type": (GGUFValueType.STRING, "model"),
    "general.architecture": (GGUFValueType.STRING, "llama"),
    "general.quantization_version": (GGUFValueType.UINT32, 2),
    "general.alignment": (GGUFValueType.UINT32, 32),
    "general.file_type": (GGUFValueType.UINT32, 0),
    "general.name": (GGUFValueType.STRING, ""),
    "general.author": (GGUFValueType.STRING, ""),
    "general.version": (GGUFValueType.STRING, ""),
    "general.organization": (GGUFValueType.STRING, ""),
    "general.finetune": (GGUFValueType.STRING, ""),
    "general.basename": (GGUFValueType.STRING, ""),
    "general.description": (GGUFValueType.STRING, ""),
    "general.quantized_by": (GGUFValueType.STRING, ""),
    "general.size_label": (GGUFValueType.STRING, ""),
    "general.license": (GGUFValueType.STRING, ""),
    "general.license.name": (GGUFValueType.STRING, ""),
    "general.license.link": (GGUFValueType.STRING, ""),
    "general.url": (GGUFValueType.STRING, ""),
    "general.doi": (GGUFValueType.STRING, ""),
    "general.uuid": (GGUFValueType.STRING, ""),
    "general.repo_url": (GGUFValueType.STRING, ""),
    "general.source.url": (GGUFValueType.STRING, ""),
    "general.source.doi": (GGUFValueType.STRING, ""),
    "general.source.uuid": (GGUFValueType.STRING, ""),
    "general.source.repo_url": (GGUFValueType.STRING, ""),
    "general.tags": (GGUFValueType.STRING, []),
    "general.languages": (GGUFValueType.STRING, []),
    "general.datasets": (GGUFValueType.STRING, []),
    "split.no": (GGUFValueType.UINT16, 0),
    "split.count": (GGUFValueType.UINT16, 0),
    "split.tensors.count": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.model": (GGUFValueType.STRING, "gpt2"),
    "tokenizer.ggml.pre": (GGUFValueType.STRING, "llama-bpe"),
    "tokenizer.ggml.tokens": (GGUFValueType.STRING, []),
    "tokenizer.ggml.token_type": (GGUFValueType.INT32, []),
    "tokenizer.ggml.scores": (GGUFValueType.FLOAT32, []),
    "tokenizer.ggml.merges": (GGUFValueType.STRING, []),
    "tokenizer.ggml.bos_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.eos_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.unknown_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.seperator_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.padding_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.cls_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.mask_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.add_bos_token": (GGUFValueType.BOOL, False),
    "tokenizer.ggml.add_eos_token": (GGUFValueType.BOOL, False),
    "tokenizer.ggml.add_space_prefix": (GGUFValueType.BOOL, False),
    "tokenizer.ggml.remove_extra_whitespaces": (GGUFValueType.BOOL, False),
    "tokenizer.chat_template": (GGUFValueType.STRING, ""),
    "tokenizer.chat_template.rag": (GGUFValueType.STRING, ""),
    "tokenizer.chat_template.tool_use": (GGUFValueType.STRING, ""),
    "tokenizer.chat_templates": (GGUFValueType.STRING, []),
    "tokenizer.ggml.prefix_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.suffix_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.middle_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.eot_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.eom_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.fim_pre_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.fim_suf_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.fim_mid_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.fim_pad_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.fim_rep_token_id": (GGUFValueType.UINT32, 0),
    "tokenizer.ggml.fim_sep_token_id": (GGUFValueType.UINT32, 0),
    "quantize.imatrix.file": (GGUFValueType.STRING, ""),
    "quantize.imatrix.dataset": (GGUFValueType.STRING, ""),
    "quantize.imatrix.entries_count": (GGUFValueType.INT32, 0),
    "quantize.imatrix.chunks_count": (GGUFValueType.INT32, 0),
}


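# Keys superseded by the tokenizer.ggml.fim_*_token_id entries.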
deprecated_metadata = {
    "tokenizer.ggml.prefix_token_id",
    "tokenizer.ggml.suffix_token_id",
    "tokenizer.ggml.middle_token_id",
}


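# On-disk size in bytes of each scalar value type.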
gguf_scalar_size: dict[GGUFValueType, int] = {
    GGUFValueType.UINT8:   1,
    GGUFValueType.INT8:    1,
    GGUFValueType.UINT16:  2,
    GGUFValueType.INT16:   2,
    GGUFValueType.UINT32:  4,
    GGUFValueType.INT32:   4,
    GGUFValueType.FLOAT32: 4,
    GGUFValueType.BOOL:    1,
    GGUFValueType.UINT64:  8,
    GGUFValueType.INT64:   8,
    GGUFValueType.FLOAT64: 8,
}


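# struct format characters for each scalar value type (the endianness prefix is
# added when packing/unpacking).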
gguf_scalar_pack: dict[GGUFValueType, str] = {
    GGUFValueType.UINT8:   "B",
    GGUFValueType.INT8:    "b",
    GGUFValueType.UINT16:  "H",
    GGUFValueType.INT16:   "h",
    GGUFValueType.UINT32:  "I",
    GGUFValueType.INT32:   "i",
    GGUFValueType.FLOAT32: "f",
    GGUFValueType.BOOL:    "?",
    GGUFValueType.UINT64:  "Q",
    GGUFValueType.INT64:   "q",
    GGUFValueType.FLOAT64: "d",
}


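# A decoded GGUF field: the declared value type (None for the magic bytes), the
# parsed Python value, and the raw bytes as stored on disk.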
class GGUFData(NamedTuple):
    type: GGUFValueType | None
    value: Any
    data: bytes


class HuggingGGUFstream:
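    """Streams the GGUF header and metadata of a file opened through fsspec.

    Each field is kept as a GGUFData tuple holding both the decoded value and
    the raw on-disk bytes, so metadata entries can be added, replaced or removed
    and re-serialized without touching the tensor data that follows them.
    """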
    fp: AbstractBufferedFile
    header: dict[str, GGUFData]
    metadata: dict[str, GGUFData]
    endian: str
    alignment: int
    metaend: int
    offset: int
    filesize: int

    def __init__(
        self,
        fp: AbstractBufferedFile,
    ) -> None:
        self.fp = fp
        self.header = {}
        self.metadata = {}
        self.endian = '<'
        self.alignment = 32
        self.metaend = 0
        self.offset = 0
        self.filesize = fp.details.get('size')

        if (data := self.fp.read(4)) != b'GGUF':
            raise TypeError('File is not a GGUF')

        self.header['magic'] = GGUFData(
            type = None,
            value = None,
            data = data,
        )

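        # Only GGUF version 3 is supported; a little-endian read that yields
        # 3 << 24 means the file is big-endian, so flip the byte order for all
        # subsequent reads and writes.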
        data = self._read_field(GGUFValueType.UINT32)
        if data.value != 3:
            if data.value == 3 << 24:
                data = GGUFData(
                    type = data.type,
                    value = 3,
                    data = data.data,
                )
                self.endian = '>'
            else:
                raise TypeError('Unsupported GGUF version')
        self.header['version'] = data

        data = self._read_field(GGUFValueType.UINT64)
        self.header['tensors'] = data

        data = self._read_field(GGUFValueType.UINT64)
        self.header['metadata'] = data

    def _unpack_field(
        self,
        buffer: bytes,
        field_type: GGUFValueType,
        repeat: int = 1,
    ) -> Any:
        value = struct.unpack(f'{self.endian}{repeat}{gguf_scalar_pack.get(field_type)}', buffer)
        return value[0] if repeat == 1 else value

    def _pack_field(
        self,
        field_type: GGUFValueType,
        *values,
    ) -> bytes:
        return struct.pack(f'{self.endian}{len(values)}{gguf_scalar_pack.get(field_type)}', *values)

    def _pack_value(
        self,
        val_type: GGUFValueType,
        value: Any,
    ) -> bytes:
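        """Serialize a metadata value; a list is written as an array: element
        type (UINT32), element count (UINT64), then the packed elements."""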
        if isinstance(value, list):
            data = self._pack_field(GGUFValueType.UINT32, val_type)
            data += self._pack_field(GGUFValueType.UINT64, len(value))

        if val_type == GGUFValueType.ARRAY:
            raise TypeError('Array of arrays currently unsupported')
        elif val_type == GGUFValueType.STRING:
            if isinstance(value, list):
                for v in value:
                    buf = str(v).encode('utf-8')
                    data += self._pack_field(GGUFValueType.UINT64, len(buf))
                    data += buf
            else:
                buf = str(value).encode('utf-8')
                data = self._pack_field(GGUFValueType.UINT64, len(buf))
                data += buf
        elif val_type in gguf_scalar_pack:
            if isinstance(value, list):
                data += self._pack_field(val_type, *value)
            else:
                data = self._pack_field(val_type, value)
        else:
            raise TypeError('Unknown metadata type')

        return data

    def _read_field(
        self,
        field_type: GGUFValueType,
        repeat: int = 1,
    ) -> GGUFData:
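        """Read `repeat` scalars of the given type, returning both the decoded
        value(s) and the raw bytes."""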
        data = self.fp.read(gguf_scalar_size.get(field_type) * repeat)
        value = self._unpack_field(data, field_type, repeat = repeat)

        return GGUFData(
            type = field_type,
            value = value,
            data = data,
        )

    def _read_value(
        self,
        val_type: GGUFValueType,
    ) -> GGUFData:
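        """Read one metadata value of the given type from the file; strings and
        arrays are length-prefixed as per the GGUF format."""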
        if val_type == GGUFValueType.ARRAY:
            data = self._read_field(GGUFValueType.UINT32)
            val_len = self._read_field(GGUFValueType.UINT64)

            if data.value in gguf_scalar_pack:
                val = self._read_field(data.value, repeat = val_len.value)
                data = GGUFData(
                    type = val.type,
                    # repeat == 1 unpacks to a bare scalar, so wrap it to keep arrays as lists
                    value = [val.value] if val_len.value == 1 else list(val.value),
                    data = data.data + val_len.data + val.data,
                )
            else:
                v = []
                d = [data.data, val_len.data]

                for _ in range(val_len.value):
                    val = self._read_value(data.value)
                    d.append(val.data)
                    v.append(val.value)

                data = GGUFData(
                    type = data.value,
                    value = v,
                    data = b''.join(d),
                )
        elif val_type == GGUFValueType.STRING:
            data = self._read_field(GGUFValueType.UINT64)
            val = self.fp.read(data.value)
            data = GGUFData(
                type = val_type,
                value = val.decode('utf-8'),
                data = data.data + val,
            )
        elif val_type in gguf_scalar_pack:
            data = self._read_field(val_type)
        else:
            raise TypeError('Unknown metadata type')

        return data

    def _update_metacount(
        self,
    ) -> None:
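        """Re-pack the metadata-count header field so it matches the current
        number of metadata entries."""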
        old_count = self.header['metadata']
        new_count = len(self.metadata)
        self.header['metadata'] = GGUFData(
            type = old_count.type,
            value = new_count,
            data = self._pack_field(old_count.type, new_count),
        )

    def read_metadata(
        self,
    ) -> Iterator[tuple[str, GGUFData]]:
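        """Yield (key, GGUFData) pairs for every metadata entry, reading and
        caching them on the first call; also records the alignment and the
        alignment offset at which the metadata block ends."""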
        if self.metadata:
            for k, v in self.metadata.items():
                yield k, v
        else:
            num_metadata = self.header['metadata'].value

            for _ in range(num_metadata):
                key = self._read_value(GGUFValueType.STRING)
                val_type = self._read_field(GGUFValueType.UINT32)
                val = self._read_value(val_type.value)

                self.metadata[key.value] = val = GGUFData(
                    type = val.type,
                    value = val.value,
                    data = key.data + val_type.data + val.data,
                )

                yield key.value, val

            if (alignment := self.metadata.get('general.alignment')) is not None:
                self.alignment = alignment.value

            self.metaend = self.fp.loc
            self.offset = self.metaend % self.alignment

    def adjust_padding(
        self,
    ) -> None:
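        """Insert (or drop) a dummy UINT8-array entry so the metadata block ends
        at the same alignment offset as when the file was read, keeping the
        tensor data offsets valid after metadata edits."""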
        if self.header['tensors'].value == 0:
            return

        dummy_key = 'dummy.padding'
        cur_metaend = 0

        for data in self.header.values():
            cur_metaend += len(data.data)

        for k, v in self.metadata.items():
            if k != dummy_key:
                cur_metaend += len(v.data)

        if (cur_metaend % self.alignment) != self.offset:
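            # Serialized size of the dummy entry minus its elements: key length
            # prefix (8) + key bytes + value type (4) + element type (4) +
            # element count (8).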
            dummy_len = 8 + len(dummy_key) + 4 + 4 + 8
            dummy_offset = (cur_metaend + dummy_len) % self.alignment
            dummy_padding = (self.alignment + (self.offset - dummy_offset)) % self.alignment

            self.add_metadata(dummy_key, GGUFValueType.UINT8, [0] * dummy_padding)
        else:
            self.remove_metadata(dummy_key)

    def add_metadata(
        self,
        key: str,
        type: GGUFValueType,
        value: Any,
    ) -> None:
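        """Add or replace a metadata entry, re-serializing the key and value and
        updating the running file size (and the metadata count for new keys)."""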
        data = self._pack_value(GGUFValueType.STRING, key)
        data += self._pack_field(GGUFValueType.UINT32, GGUFValueType.ARRAY if isinstance(value, list) else type)
        data += self._pack_value(type, value)

        if (meta := self.metadata.get(key)):
            self.filesize -= len(meta.data)

        self.filesize += len(data)
        self.metadata[key] = GGUFData(
            type = type,
            value = value,
            data = data,
        )

        if not meta:
            self._update_metacount()

    def remove_metadata(
        self,
        key: str,
    ) -> None:
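        """Remove a metadata entry if present, updating the running file size
        and the metadata count."""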
        if (meta := self.metadata.get(key)):
            del self.metadata[key]

            self.filesize -= len(meta.data)
            self._update_metacount()
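

if __name__ == '__main__':
    # Minimal usage sketch, not part of the library proper: it assumes
    # `huggingface_hub` is installed and that the placeholder repo and filename
    # below are replaced with a real GGUF file on the Hugging Face Hub.
    from huggingface_hub import HfFileSystem

    fs = HfFileSystem()
    with fs.open('some-user/some-model-GGUF/model.Q4_K_M.gguf', 'rb') as fp:
        gguf = HuggingGGUFstream(fp)
        print('version:', gguf.header['version'].value)
        print('tensors:', gguf.header['tensors'].value)

        # Stream the metadata without fetching the tensor data.
        for key, field in gguf.read_metadata():
            print(f'{key}: {repr(field.value)[:80]}')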