import sys
from copy import deepcopy
from typing import Optional, Tuple

import torch
from flash_attn import flash_attn_func
from transformers.modeling_outputs import CausalLMOutput

from ..ops.streaming_kernel import TritonMultiStageDotProductionAttention


class CudaCache:
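    # Fixed pool of reusable GPU buffers shared by all offloaded KV blocks:
    # alloc() hands out an idle slot, delete() returns it to the pool.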
    def __init__(self, num_units, unit_size, dtype):
        self.num_units = num_units
        self.unit_size = unit_size
        self.dtype = dtype
        self.data = torch.empty((num_units, unit_size), device="cuda", dtype=dtype)
        self.idle_set = set(range(num_units))

    def alloc(self):
        assert len(self.idle_set) > 0

        idx = self.idle_set.pop()
        return self.data[idx], idx

    def delete(self, idx):
        assert idx not in self.idle_set
        self.idle_set.add(idx)


class MemoryUnit:
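    # One block of KV cache. A CPU copy (optionally pinned) is always kept;
    # load()/offload() move a GPU copy in and out of a slot borrowed from the
    # shared CudaCache.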
    def __init__(
        self,
        kv: Tuple[torch.Tensor, torch.Tensor],
        cache: CudaCache,
        load_to_cache: bool = False,
        pin_memory: bool = False,
    ):
        self.cache = cache

        if kv[0].is_cuda:
            cpu_data = tuple(_t.contiguous().to("cpu", non_blocking=True) for _t in kv)
        else:
            cpu_data = tuple(_t.contiguous() for _t in kv)

        if pin_memory:
            cpu_data = tuple(_t.pin_memory() for _t in cpu_data)

        if load_to_cache:
            gpu_data, gpu_data_id = cache.alloc()
            gpu_data = gpu_data.view((2,) + kv[0].shape)
            gpu_data[0].copy_(kv[0], non_blocking=True)
            gpu_data[1].copy_(kv[1], non_blocking=True)
            event = torch.cuda.Event()
            event.record(torch.cuda.current_stream())
        else:
            gpu_data, gpu_data_id = None, None
            event = None

        self.cpu_data = cpu_data
        self.gpu_data = gpu_data
        self.gpu_data_id = gpu_data_id
        self.event = event

    def load(
        self, target: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
    ) -> Tuple[bool, Optional[torch.cuda.Event]]:
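        # Returns (newly_loaded, copy_event): newly_loaded is False when the block
        # already has a GPU copy; copy_event is None when no target is given.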
        if self.gpu_data is not None:
            if target is not None:
                target[0].copy_(self.gpu_data[0], non_blocking=True)
                target[1].copy_(self.gpu_data[1], non_blocking=True)
                target_event = torch.cuda.Event()
                target_event.record(torch.cuda.current_stream())
            else:
                target_event = None

            return False, target_event

        gpu_data, gpu_data_id = self.cache.alloc()
        gpu_data = gpu_data.view((2,) + self.cpu_data[0].shape)
        if target is not None:
            target[0].copy_(self.cpu_data[0], non_blocking=True)
            target[1].copy_(self.cpu_data[1], non_blocking=True)
            target_event = torch.cuda.Event()
            target_event.record(torch.cuda.current_stream())
            gpu_data[0].copy_(target[0], non_blocking=True)
            gpu_data[1].copy_(target[1], non_blocking=True)

        else:
            target_event = None
            gpu_data[0].copy_(self.cpu_data[0], non_blocking=True)
            gpu_data[1].copy_(self.cpu_data[1], non_blocking=True)

        event = torch.cuda.Event()
        event.record(torch.cuda.current_stream())
        self.event = event
        self.gpu_data = gpu_data
        self.gpu_data_id = gpu_data_id

        return True, target_event

    def get(self):
        assert self.gpu_data is not None
        self.event.wait()
        return self.gpu_data

    def offload(self):
        assert self.gpu_data is not None
        self.event.wait()
        self.gpu_data = None
        self.cache.delete(self.gpu_data_id)
        self.gpu_data_id = None


class VectorTensor:
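    # Growable GPU matrix of per-block representative keys; capacity doubles on
    # overflow and get_topk() does an exact inner-product top-k lookup.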
    def __init__(self, hidden_size, element_dtype):
        init_cached_size = 16
        self.data = torch.empty(
            (init_cached_size, hidden_size), dtype=element_dtype, device="cuda"
        )
        self.length = 0
        self.cache_size = init_cached_size
        self.hidden_size = hidden_size

    def append_cache(self):
        new_cache_size = self.cache_size * 2
        data_shape = self.data.shape
        new_data = torch.empty(
            (new_cache_size,) + data_shape[1:], device="cuda", dtype=self.data.dtype
        )
        new_data[: self.cache_size, ...].copy_(self.data)
        self.data = new_data
        self.cache_size = new_cache_size

    def append(self, tensor: torch.Tensor):
        assert tensor.dtype == self.data.dtype
        assert tensor.size(1) == self.hidden_size
        assert tensor.is_contiguous()

        append_l = tensor.size(0)

        while self.length + append_l > self.cache_size:
            self.append_cache()

        self.data[self.length : self.length + append_l, ...].copy_(tensor)

        self.length += append_l

    def get_data(self):
        return self.data[: self.length, ...]

    def get_topk(self, tensor: torch.Tensor, topk):  # inner product
        assert tensor.dim() == 1 and tensor.size(0) == self.hidden_size
        logits = torch.matmul(self.data[: self.length], tensor[:, None]).squeeze(dim=-1)
        assert logits.dim() == 1 and logits.size(0) == self.length
        return logits.topk(topk, dim=0).indices.cpu().tolist()

    def __len__(self):
        return self.length


class Faiss:
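    # Alternative block index that keeps representative keys in a CPU
    # faiss.IndexFlatIP and answers top-k queries by exact inner-product search.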
    def __init__(self, hidden_size, element_dtype):
        import faiss

        # We use the CPU index here because the GPU index requires a long initialization time
        self.index = faiss.IndexFlatIP(hidden_size)
        self.hidden_size = hidden_size

    def append(self, tensor: torch.Tensor):
        assert tensor.dim() == 2 and tensor.size(1) == self.hidden_size
        self.index.add(tensor.cpu().float().numpy().astype("float32"))

    def get_data(self):
        raise ValueError("Faiss index does not store raw key data; use get_topk instead")

    def get_topk(self, tensor: torch.Tensor, topk):
        assert tensor.dim() == 1 and tensor.size(0) == self.hidden_size
        xq = tensor[None, :].cpu().float().numpy().astype("float32")
        topk_index = self.index.search(xq, topk)[1][0].tolist()
        return topk_index

    def __len__(self):
        return self.index.ntotal


# CUDA stream for global-memory bookkeeping when async_global_stream is enabled;
# stays None (i.e. the current stream) otherwise.
GLOBAL_STREAM = None


class ContextManager:
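    # Per-layer KV manager used by inf_llm_forward: recent tokens stay in a local
    # sliding window (n_local), the first n_init tokens are always kept, and older
    # tokens are packed into fixed-size blocks that are offloaded to CPU and loaded
    # back on demand via representative-key top-k retrieval.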
    def __init__(
        self,
        position_embedding,
        n_init,
        n_local,
        block_size,
        max_cached_block,
        topk,
        exc_block_size,
        score_decay: Optional[float] = None,
        repr_topk: int = 1,
        cache_strategy="lru",
        chunk_topk_calc: Optional[int] = None,
        async_global_stream: bool = False,
        pin_memory: bool = False,
        faiss: bool = False,
        perhead: bool = False,
        dense_decoding: bool = False,
    ):
        self.length = 0
        self.position_embedding = position_embedding
        self.n_init = n_init
        self.n_local = n_local
        self.block_size = block_size
        self.max_cached_block = max_cached_block
        self.exc_block_size = exc_block_size
        self.score_decay = score_decay
        assert exc_block_size <= n_local  # no global token in input
        self.topk = topk
        self.Attn = TritonMultiStageDotProductionAttention
        self.initialized = False
        self.repr_topk = repr_topk
        self.cache_strategy = cache_strategy
        self.load_count = 0
        self.chunk_topk_calc = chunk_topk_calc
        self.async_global_stream = async_global_stream
        self.pin_memory = pin_memory
        self.faiss = faiss
        self.perhead = perhead

        self.dense_decoding = dense_decoding

        global GLOBAL_STREAM
        if self.async_global_stream and GLOBAL_STREAM is None:
            GLOBAL_STREAM = torch.cuda.Stream()

        assert cache_strategy in ["lru", "lru-s"]

        if cache_strategy == "lru-s":
            self.calc_block_score = True
        else:
            self.calc_block_score = False

    def remove_lru_blocks(
        self, u, num_remove: Optional[int] = None, ignore_blocks=None
    ):
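        # Evict the lowest-scored cached blocks (oldest access for "lru", lowest
        # accumulated score for "lru-s"), skipping any block in ignore_blocks.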
        if num_remove is None:
            num_remove = len(self.cached_blocks[u]) - self.max_cached_block

        if num_remove <= 0:
            return

        lst = list(self.cached_blocks[u].items())
        lst.sort(key=lambda x: x[1])

        removed = 0
        for i in range(len(lst)):
            idx = lst[i][0]
            if ignore_blocks is None or (idx not in ignore_blocks):
                self.global_blocks[u][idx].offload()
                self.cached_blocks[u].pop(idx)
                removed += 1

            if removed >= num_remove:
                return

    def get_block_k(self, k, score):
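        # Gather the repr_topk highest-scored key vectors of a block per head;
        # their mean later becomes the block's representative key.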
        assert isinstance(score, torch.Tensor)
        assert k.dim() >= 2
        k = self.from_group_kv(k)
        assert k.shape[:-1] == score.shape
        assert k.shape[-2] == self.block_size
        score_topk = score.topk(self.repr_topk, dim=-1).indices
        assert score_topk.shape == (self.num_units, self.unit_size, self.repr_topk)
        ret = torch.gather(
            k,
            -2,
            score_topk[:, :, :, None].expand(
                self.num_units, self.unit_size, self.repr_topk, self.dim_head
            ),
        )
        return ret

    def from_group_kv(self, tensor):
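        # Expand grouped (GQA) KV heads so they line up with the query heads.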
        assert tensor.dim() == 4
        assert tensor.size(1) == self.num_heads_kv
        if self.num_heads == self.num_heads_kv:
            return tensor
        _, _, length, dim_head = tensor.shape
        num_group = self.num_heads // self.num_heads_kv
        tensor = tensor.view((self.num_units, self.unit_size_kv, 1, length, dim_head))
        tensor = tensor.expand(
            (self.num_units, self.unit_size_kv, num_group, length, dim_head)
        ).reshape((self.num_units, self.num_heads, length, dim_head))
        return tensor

    def init(self, local_q, local_k, local_v, global_q, global_k, global_v):
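        # Lazily allocate all per-layer state (local window, init tokens, remainder
        # buffers, block index, global buffer and CUDA cache) from the first chunk.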
        assert local_q.dim() == 4
        batch_size, num_heads, len_q, dim_head = local_q.shape
        num_heads_kv = local_k.size(1)

        for _t in [local_q, local_k, local_v, global_q, global_k, global_v]:
            assert _t.size(0) == batch_size
            assert _t.size(1) == num_heads or _t.size(1) == num_heads_kv
            assert _t.size(2) == len_q
            assert _t.size(3) == dim_head
            assert _t.is_cuda

        self.batch_size = batch_size
        self.num_heads = num_heads
        self.num_heads_kv = num_heads_kv
        self.dim_head = dim_head
        self.num_units = batch_size
        self.unit_size = num_heads
        self.unit_size_kv = num_heads_kv

        self.global_blocks = [[] for _ in range(self.num_units)]  # [[memory_unit]]
        self.cached_blocks = [
            {} for _ in range(self.num_units)
        ]  # [{block_id: block_score}]
        self.num_global_block = 0

        if self.faiss:
            self.block_k = [
                Faiss(dim_head * self.unit_size, global_k.dtype)
                for _ in range(self.num_units)
            ]
        else:
            self.block_k = [
                VectorTensor(dim_head * self.unit_size, global_k.dtype)
                for _ in range(self.num_units)
            ]

        self.local_k = torch.empty(
            (self.num_units, self.unit_size_kv, 0, dim_head),
            dtype=local_k.dtype,
            device=local_k.device,
        )
        self.local_v = torch.empty(
            (self.num_units, self.unit_size_kv, 0, dim_head),
            dtype=local_v.dtype,
            device=local_v.device,
        )

        if self.dense_decoding:
            self.dense_k = torch.empty(
                (self.num_units, self.unit_size_kv, 0, dim_head),
                dtype=local_k.dtype,
                device=local_k.device,
            )
            self.dense_v = torch.empty(
                (self.num_units, self.unit_size_kv, 0, dim_head),
                dtype=local_v.dtype,
                device=local_v.device,
            )

        self.global_remainder = (
            torch.empty(
                (self.num_units, self.unit_size_kv, 0, dim_head),
                dtype=global_k.dtype,
                device=global_k.device,
            ),
            torch.empty(
                (self.num_units, self.unit_size_kv, 0, dim_head),
                dtype=global_v.dtype,
                device=global_v.device,
            ),
        )

        self.global_remainder_local_score = torch.empty(
            (self.num_units, self.unit_size, 0),
            dtype=global_k.dtype,
            device=global_k.device,
        )

        self.init_k = torch.empty(
            (self.num_units, self.unit_size_kv, 0, dim_head),
            dtype=global_k.dtype,
            device=global_k.device,
        )
        self.init_v = torch.empty(
            (self.num_units, self.unit_size_kv, 0, dim_head),
            dtype=global_k.dtype,
            device=global_k.device,
        )
        self.init_exc = False
        self.dtype = local_q.dtype
        self.position_embedding._update_cos_sin_tables_len(
            self.n_local + self.exc_block_size + 1, local_k.device, local_k.dim()
        )

        buffer_len = (
            self.topk * self.block_size
            + self.exc_block_size
            + self.block_size
            + self.n_init
        )
        self.global_buffer = torch.zeros(
            (2, self.num_units, self.unit_size_kv, buffer_len, dim_head),
            dtype=global_k.dtype,
            device=global_k.device,
        )
        self.global_buffer_block_id_list = [
            [-1] * self.topk for _ in range(self.num_units)
        ]
        self.global_buffer_init_st = 0
        self.global_buffer_init_ed = 0
        self.cuda_cache = CudaCache(
            self.max_cached_block * self.num_units,
            self.unit_size_kv * self.block_size * dim_head * 2,
            local_k.dtype,
        )

        self.initialized = True

    def calc_block_topk(self, global_h_q):
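        # Pick the top-k blocks for this execution block by inner product between the
        # mean query and each block's representative key; return all blocks when fewer
        # than topk exist, or the precomputed result in chunked top-k mode.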
        if self._use_chunk_topk:
            return self._cached_topk[self._topk_cur]

        if self.num_global_block <= self.topk:
            return [
                list(range(len(self.global_blocks[0])))
                for _ in range(self.num_units)
            ]

        global_h_q = global_h_q.mean(dim=2, keepdim=False)
        assert global_h_q.shape == (self.num_units, self.unit_size, self.dim_head)
        global_h_q = global_h_q.reshape(
            self.num_units, self.dim_head * self.unit_size
        )
        return [
            self.block_k[u].get_topk(global_h_q[u], self.topk)
            for u in range(self.num_units)
        ]

    def get_global_hidden_and_mask(self, len_q, block_topk):
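        # Assemble the contiguous global K/V buffer for this execution block as
        # [top-k loaded blocks | init tokens | un-blocked remainder], reusing buffer
        # slots whose blocks are already resident, and return the sliding-window spec
        # that excludes positions already covered by local attention.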
        assert len(block_topk) == self.num_units
        global_block_map = [[] for _ in range(self.num_units)]
        global_remainder_len = max(
            self._global_remainder_ed
            - self._global_remainder_st
            + len_q
            - self.n_local,
            0,
        )
        init_len = self.init_k.size(-2)
        sliding_window = None

        global_h_k = self.global_buffer[0]
        global_h_v = self.global_buffer[1]

        block_num = len(block_topk[0])
        for u in range(self.num_units):
            assert len(block_topk[u]) == block_num

            block_topk[u].sort()
            global_block_map[u] = deepcopy(self.global_buffer_block_id_list[u])
            for b_idx in block_topk[u]:
                if b_idx in global_block_map[u]:
                    continue

                st = -1
                ed = -1
                for j in range(self.topk):
                    if (
                        global_block_map[u][j] == -1
                        or global_block_map[u][j] not in block_topk[u]
                    ):
                        st = j * self.block_size
                        ed = st + self.block_size
                        global_block_map[u][j] = b_idx
                        break

                assert b_idx in self.cached_blocks[u]
                self.global_blocks[u][b_idx].load(
                    (global_h_k[u, :, st:ed, :], global_h_v[u, :, st:ed, :])
                )

        init_st = block_num * self.block_size
        init_ed = init_st + init_len
        if (
            self.global_buffer_init_st != init_st
            or self.global_buffer_init_ed != init_ed
        ):
            global_h_k[:, :, init_st:init_ed, :].copy_(self.init_k, non_blocking=True)
            global_h_v[:, :, init_st:init_ed, :].copy_(self.init_v, non_blocking=True)

        ed = init_ed

        rmd_st = init_ed
        rmd_ed = rmd_st + global_remainder_len
        ed = rmd_ed
        global_h_k[:, :, rmd_st:rmd_ed, :].copy_(
            self.global_remainder[0][
                :,
                :,
                self._global_remainder_st : self._global_remainder_st
                + global_remainder_len,
                :,
            ],
            non_blocking=True,
        )
        global_h_v[:, :, rmd_st:rmd_ed, :].copy_(
            self.global_remainder[1][
                :,
                :,
                self._global_remainder_st : self._global_remainder_st
                + global_remainder_len,
                :,
            ],
            non_blocking=True,
        )

        sliding_window = (self.global_remainder[0].size(-2) + rmd_st, self.n_local)

        self.global_buffer_block_id_list = deepcopy(global_block_map)
        self.global_buffer_init_st = init_st
        self.global_buffer_init_ed = init_ed

        for u in range(self.num_units):
            assert max(global_block_map[u][block_num:] + [-1]) == -1
            assert min(global_block_map[u][:block_num] + [0]) > -1
            global_block_map[u] = list(global_block_map[u][:block_num])

        global_h_k = global_h_k[:, :, :ed, :]
        global_h_v = global_h_v[:, :, :ed, :]
        return global_h_k, global_h_v, sliding_window, global_block_map, block_num

    def update_block_score(
        self, global_score: torch.FloatTensor, global_block_map, global_block_num
    ):
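        # "lru-s" bookkeeping: decay every cached block's score and add the attention
        # mass each loaded block received in this execution block.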
        if global_score is not None:
            global_score = global_score[:, :, : global_block_num * self.block_size]
            assert global_score.shape == (
                self.num_units,
                self.unit_size,
                global_block_num * self.block_size,
            )
            global_score = global_score.view(
                self.num_units, self.unit_size, global_block_num, self.block_size
            )
            global_score = global_score.sum(dim=-1).sum(dim=1)
            assert global_score.shape == (self.num_units, global_block_num)
            global_score = global_score.to(
                device="cpu", non_blocking=False
            )  # (num_units, global_block_num)
            for u in range(self.num_units):
                for k, v in self.cached_blocks[u].items():
                    self.cached_blocks[u][k] = v * self.score_decay
                score = global_score[u].tolist()
                assert len(score) >= len(global_block_map[u])
                for s, i in zip(score, global_block_map[u]):
                    self.cached_blocks[u][i] += s

    def _append(self, local_q, local_k, local_v, global_q):
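        # Process one execution block: run local sliding-window attention first (to
        # overlap with cache transfers), then attention over the retrieved global
        # context, and let the multi-stage kernel merge the two partial results.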
        # get local_h_q, local_h_k, local_h_v
        local_h_q, local_h_k = self.position_embedding(local_q, local_k)
        local_h_v = local_v

        # calc local result first to overlap host-device communication
        attn = self.Attn(local_h_q.shape, local_h_q.dtype, local_h_q.device)
        attn.append(
            local_h_q, local_h_k, local_h_v, get_score=True, sliding_window=self.n_local
        )

        # calc topk global repr k and load cache
        with torch.cuda.stream(GLOBAL_STREAM):
            block_topk = self.calc_block_topk(global_q)

            for u in range(self.num_units):
                num_remove = len(self.cached_blocks[u]) - self.max_cached_block
                for bidx in block_topk[u]:
                    if bidx not in self.cached_blocks[u]:
                        num_remove += 1

                # update cache
                self.remove_lru_blocks(u, num_remove, block_topk[u])

            if self.cache_strategy == "lru":
                self.load_count += 1
                for u in range(self.num_units):
                    for bidx in block_topk[u]:
                        self.cached_blocks[u][bidx] = self.load_count

            elif self.cache_strategy == "lru-s":
                for u in range(self.num_units):
                    for bidx in block_topk[u]:
                        self.cached_blocks[u][bidx] = 0
            else:
                raise ValueError

            # get global_h_k, global_h_v, global_mask
            #    Because exc_block_size <= n_local, no global_k or global_v is used
            #    in the global part
            global_h_q = global_q
            (
                global_h_k,
                global_h_v,
                global_sliding_window,
                global_block_map,
                global_block_num,
            ) = self.get_global_hidden_and_mask(local_h_q.size(-2), block_topk)

        if self.async_global_stream:
            torch.cuda.current_stream().wait_stream(GLOBAL_STREAM)

        # calc global result
        attn.append(
            global_h_q,
            global_h_k,
            global_h_v,
            end=True,
            get_score=self.calc_block_score,
            sliding_window=global_sliding_window,
            complement_sliding_window=True,
        )

        o, score_list = attn.get_result()
        loc_score = score_list[0]
        glb_score = score_list[1]

        if self.async_global_stream:
            GLOBAL_STREAM.wait_stream(torch.cuda.current_stream())

        # update global score
        with torch.cuda.stream(GLOBAL_STREAM):
            self.update_block_score(glb_score, global_block_map, global_block_num)

        return o.view((self.batch_size, self.num_heads, -1, self.dim_head)), loc_score

    def get_batched_topk(self, global_q):
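        # Precompute block top-k for several execution blocks at once (chunked
        # top-k mode), scoring each exc_block_size segment by its mean query.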
        length = global_q.shape[2]
        exc_num = (length + self.exc_block_size - 1) // self.exc_block_size
        exc_block_num = length // self.exc_block_size
        ret = []
        if self.num_global_block <= self.topk:
            for _ in range(exc_num):
                ret.append(
                    [
                        list(range(len(self.global_blocks[0])))
                        for _ in range(self.num_units)
                    ]
                )
            return ret

        global_h_q = global_q
        assert global_h_q.dim() == 4
        assert global_h_q.shape[:2] == (self.num_units, self.unit_size)
        assert global_h_q.shape[3] == self.dim_head

        block_k = torch.cat(
            [self.block_k[u].get_data()[None, :, :] for u in range(self.num_units)],
            dim=0,
        )
        assert block_k.shape == (
            self.num_units,
            self.num_global_block,
            self.dim_head * self.unit_size,
        )
        block_k = (
            block_k.reshape(
                self.num_units, self.num_global_block, self.unit_size, self.dim_head
            )
            .permute(0, 2, 1, 3)
            .contiguous()
        )

        if exc_block_num > 0:
            tmp_global_h_q = (
                global_h_q[:, :, : exc_block_num * self.exc_block_size, :]
                .reshape(
                    self.num_units,
                    self.unit_size,
                    exc_block_num,
                    self.exc_block_size,
                    self.dim_head,
                )
                .mean(dim=-2)
            )
            assert tmp_global_h_q.shape == (
                self.num_units,
                self.unit_size,
                exc_block_num,
                self.dim_head,
            )
            block_score = torch.matmul(tmp_global_h_q, block_k.transpose(-1, -2)).mean(
                dim=1
            )  # (num_units, exc_block_num, num_global_block)
            assert block_score.shape == (
                self.num_units,
                exc_block_num,
                self.num_global_block,
            )

            indices = block_score.topk(self.topk, dim=-1).indices.cpu()
            for b in range(exc_block_num):
                tmp = []
                for u in range(self.num_units):
                    tmp.append(indices[u, b].tolist())
                    assert len(tmp[-1]) == self.topk

                ret.append(tmp)

        if exc_block_num != exc_num:
            tmp_global_h_q = (
                global_h_q[:, :, exc_block_num * self.exc_block_size :, :]
                .reshape(
                    self.num_units,
                    self.unit_size,
                    length - exc_block_num * self.exc_block_size,
                    self.dim_head,
                )
                .mean(dim=-2, keepdim=True)
            )
            assert tmp_global_h_q.shape == (
                self.num_units,
                self.unit_size,
                1,
                self.dim_head,
            )
            block_score = torch.matmul(tmp_global_h_q, block_k.transpose(-1, -2))
            assert block_score.shape == (
                self.num_units,
                self.unit_size,
                1,
                self.num_global_block,
            )
            block_score = block_score.squeeze(dim=2).mean(dim=1)
            assert block_score.shape == (self.num_units, self.num_global_block)
            indices = block_score.topk(self.topk, dim=-1).indices.cpu()
            tmp = []
            for u in range(self.num_units):
                tmp.append(indices[u].tolist())
                assert len(tmp[-1]) == self.topk

            ret.append(tmp)

        return ret

    def append_global(self, exc_length, kv_length, local_score):
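        # Fold newly processed tokens into global memory: accumulate their local
        # attention scores, fill the init tokens up to n_init, then cut the remainder
        # into block_size blocks stored on CPU as MemoryUnits, each contributing one
        # representative key to the block index.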
        global_remainder_ed = self._global_remainder_ed + exc_length
        global_remainder_st = self._global_remainder_st

        global_remainder_len = global_remainder_ed - global_remainder_st

        assert local_score.shape[:3] == (self.num_units, self.unit_size, kv_length)
        local_score = local_score[:, :, -exc_length - self.n_local :]
        self.global_remainder_local_score[
            :, :, global_remainder_ed - local_score.size(-1) : global_remainder_ed
        ].add_(local_score)

        if not self.init_exc and global_remainder_len > self.n_local:
            global_k = self.global_remainder[0]
            global_v = self.global_remainder[1]

            append_init_len = min(
                self.n_init - self.init_k.size(-2), global_remainder_len - self.n_local
            )
            self.init_k = torch.cat(
                (
                    self.init_k,
                    global_k[
                        :,
                        :,
                        global_remainder_st : global_remainder_st + append_init_len,
                        :,
                    ],
                ),
                dim=-2,
            )
            self.init_v = torch.cat(
                (
                    self.init_v,
                    global_v[
                        :,
                        :,
                        global_remainder_st : global_remainder_st + append_init_len,
                        :,
                    ],
                ),
                dim=-2,
            )
            global_remainder_st += append_init_len
            global_remainder_len -= append_init_len

            if self.init_k.size(-2) == self.n_init:
                self.init_exc = True

        while global_remainder_len - self.block_size >= self.n_local:
            global_remainder_len -= self.block_size
            for u in range(self.num_units):
                self.global_blocks[u].append(
                    MemoryUnit(
                        (
                            self.global_remainder[0][
                                u,
                                :,
                                global_remainder_st : global_remainder_st
                                + self.block_size,
                                :,
                            ],
                            self.global_remainder[1][
                                u,
                                :,
                                global_remainder_st : global_remainder_st
                                + self.block_size,
                                :,
                            ],
                        ),
                        self.cuda_cache,
                        False,
                        self.pin_memory,
                    )
                )

            global_block_k = self.get_block_k(
                self.global_remainder[0][
                    :, :, global_remainder_st : global_remainder_st + self.block_size, :
                ],
                self.global_remainder_local_score[
                    :, :, global_remainder_st : global_remainder_st + self.block_size
                ],
            )
            assert global_block_k.shape == (
                self.num_units,
                self.unit_size,
                self.repr_topk,
                self.dim_head,
            )
            global_block_k = global_block_k.mean(dim=-2, keepdim=False)
            global_block_k = global_block_k.reshape(
                self.num_units, self.unit_size * self.dim_head
            )
            global_block_k = global_block_k[:, None, :]

            self.num_global_block += 1
            for u in range(self.num_units):
                self.block_k[u].append(global_block_k[u])
            global_remainder_st += self.block_size

        self._global_remainder_ed = global_remainder_ed
        self._global_remainder_st = global_remainder_st

    def append(
        self,
        local_q,
        local_k,
        local_v,
        global_q,
        global_k,
        global_v,
    ):
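        # Main entry point: ingest a chunk of Q/K/V, process it exc_block_size tokens
        # at a time with _append, move finished tokens from the remainder into global
        # memory, and trim the local window back to n_local.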
        batch_size = local_q.size(0)
        input_length = local_q.size(-2)

        if self.perhead:
            num_heads = local_q.size(1)
            num_heads_kv = local_v.size(1)

            def repeat_kv(t):
                t = t.view(batch_size, num_heads_kv, 1, input_length, -1)
                t = t.expand(
                    batch_size,
                    num_heads_kv,
                    num_heads // num_heads_kv,
                    input_length,
                    -1,
                )
                t = t.reshape(batch_size * num_heads, 1, input_length, -1)
                return t

            local_q = local_q.view(batch_size * num_heads, 1, input_length, -1)
            local_k = repeat_kv(local_k)
            local_v = repeat_kv(local_v)
            global_q = global_q.view(batch_size * num_heads, 1, input_length, -1)
            global_k = repeat_kv(global_k)
            global_v = repeat_kv(global_v)

        if not self.initialized:
            self.init(local_q, local_k, local_v, global_q, global_k, global_v)

        input_length = local_q.size(-2)

        if self.async_global_stream:
            GLOBAL_STREAM.wait_stream(torch.cuda.current_stream())

        # append local and global tensor
        self.local_k = torch.cat((self.local_k, local_k), dim=-2)
        self.local_v = torch.cat((self.local_v, local_v), dim=-2)
        kv_length = self.local_k.size(-2)

        if self.dense_decoding:
            self.dense_k = torch.cat((self.dense_k, local_k), dim=-2)
            self.dense_v = torch.cat((self.dense_v, local_v), dim=-2)

        # append global remainder
        with torch.cuda.stream(GLOBAL_STREAM):
            self._global_remainder_st = 0
            self._global_remainder_ed = self.global_remainder[0].size(-2)

            self.global_remainder = (
                torch.cat((self.global_remainder[0], global_k), dim=-2),
                torch.cat((self.global_remainder[1], global_v), dim=-2),
            )

            self.global_remainder_local_score = torch.cat(
                (
                    self.global_remainder_local_score,
                    torch.zeros(
                        (self.num_units, self.unit_size, global_k.size(-2)),
                        dtype=global_k.dtype,
                        device=global_k.device,
                    ),
                ),
                dim=-1,
            )

        with torch.cuda.stream(GLOBAL_STREAM):
            global_q = self.position_embedding.apply_rotary_pos_emb_one_angle(
                global_q, self.n_local
            )

        use_chunk_topk = self.chunk_topk_calc is not None and input_length > 1
        self._use_chunk_topk = use_chunk_topk
        if use_chunk_topk:
            exc_block_num = input_length // self.exc_block_size
            exc_block_per_topk_chunk = self.chunk_topk_calc // self.exc_block_size
            calc_cur_list = [
                i * self.exc_block_size
                for i in range(0, exc_block_num + 1, exc_block_per_topk_chunk)
            ]
            if calc_cur_list[-1] < input_length:
                calc_cur_list.append(input_length)
            self._topk_cur = 0
            self._topk_calc_cur = -1

        o_list = []

        for st in range(0, input_length, self.exc_block_size):
            ed = min(st + self.exc_block_size, input_length)
            if use_chunk_topk and calc_cur_list[self._topk_calc_cur + 1] < ed:
                # calculate topk and sync with host here
                assert ed <= calc_cur_list[self._topk_calc_cur + 2]
                self._topk_calc_cur += 1
                with torch.cuda.stream(GLOBAL_STREAM):
                    self._cached_topk = self.get_batched_topk(
                        global_q[
                            :,
                            :,
                            calc_cur_list[self._topk_calc_cur] : calc_cur_list[
                                self._topk_calc_cur + 1
                            ],
                            :,
                        ]
                    )
                self._topk_cur = 0

            kv_st = max(kv_length + st - input_length - self.n_local, 0)
            kv_ed = kv_length + ed - input_length
            chunk_o, local_score = self._append(
                local_q[:, :, st:ed, :],
                self.local_k[:, :, kv_st:kv_ed, :],
                self.local_v[:, :, kv_st:kv_ed, :],
                global_q[:, :, st:ed, :],
            )
            o_list.append(chunk_o)

            # append global
            with torch.cuda.stream(GLOBAL_STREAM):
                self.append_global(ed - st, kv_ed - kv_st, local_score)

            if self.async_global_stream:
                torch.cuda.current_stream().wait_stream(GLOBAL_STREAM)

            if use_chunk_topk:
                self._topk_cur += 1

        self.length += input_length

        # update local and global tensor
        if self.local_k.size(-2) >= self.n_local:
            self.local_k = self.local_k[:, :, -self.n_local :, :]
            self.local_v = self.local_v[:, :, -self.n_local :, :]

        assert self._global_remainder_ed == self.global_remainder[0].size(-2)
        with torch.cuda.stream(GLOBAL_STREAM):
            self.global_remainder = (
                self.global_remainder[0][:, :, self._global_remainder_st :, :],
                self.global_remainder[1][:, :, self._global_remainder_st :, :],
            )
            self.global_remainder_local_score = self.global_remainder_local_score[
                :, :, self._global_remainder_st :
            ]

        ret = torch.cat(o_list, dim=-2)

        if self.perhead:
            ret = ret.view(batch_size, num_heads, input_length, -1)

        return ret

    def size(self, *args, **kwargs):
        return self.length


def inf_llm_forward(
    n_local,
    n_init,
    topk,
    block_size,
    max_cached_block,
    exc_block_size,
    repr_topk: int = 1,
    cache_strategy="lru",
    score_decay=None,
    chunk_topk_calc=None,
    async_global_stream=True,
    pin_memory=False,
    faiss=False,
    perhead=False,
    dense_decoding=False,
    *args,
    **kwargs
):
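    # Build an attention forward pass that routes Q/K/V through a per-layer
    # ContextManager (kept in past_key_value). With dense_decoding, single-token
    # decoding instead uses a full dense KV cache and flash attention.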
    def forward(
        self,
        query: torch.Tensor,
        key_value: torch.Tensor,
        position_bias: Optional[torch.Tensor],
        use_cache: bool,
        past_key_value,
        project_q,
        project_k,
        project_v,
        attention_out,
        dim_head,
        num_heads,
        num_heads_kv,
    ):
        batch_size = query.size(0)
        len_q = query.size(1)
        len_k = key_value.size(1)

        # assert use_cache

        h_q = project_q(query)  # (batch, len_q, num_heads * dim_head)
        h_k = project_k(key_value)  # (batch, len_k, num_heads * dim_head)
        h_v = project_v(key_value)  # (batch, len_k, num_heads * dim_head)

        h_q = (
            h_q.view(batch_size, len_q, num_heads, dim_head)
            .permute(0, 2, 1, 3)
            .contiguous()
        )  # (batch, num_heads, len_q, dim_head)
        h_k = (
            h_k.view(batch_size, len_k, num_heads_kv, dim_head)
            .permute(0, 2, 1, 3)
            .contiguous()
        )  # (batch, num_heads_kv, len_k, dim_head)
        h_v = (
            h_v.view(batch_size, len_k, num_heads_kv, dim_head)
            .permute(0, 2, 1, 3)
            .contiguous()
        )  # (batch, num_heads_kv, len_k, dim_head)

        if len_q == 1 and dense_decoding:
            past_k = past_key_value.dense_k
            past_v = past_key_value.dense_v

            h_k = torch.cat((past_k, h_k), dim=-2)
            h_v = torch.cat((past_v, h_v), dim=-2)

            past_key_value.dense_k = h_k
            past_key_value.dense_v = h_v

            h_q, h_k = position_bias(h_q, h_k)

            # (batch_size, seqlen, nheads, headdim)
            h_q = h_q.transpose(1, 2)
            h_k = h_k.transpose(1, 2)
            h_v = h_v.transpose(1, 2)

            # (batch_size, seqlen, nheads, headdim)
            o = flash_attn_func(h_q, h_k, h_v, causal=True)

            o = o.reshape(batch_size, len_q, dim_head * num_heads)
            o = attention_out(o)

            if use_cache:
                return o, past_key_value
            else:
                return o

        if past_key_value is None:
            past_key_value = ContextManager(
                position_bias,
                n_init,
                n_local,
                block_size,
                max_cached_block,
                topk,
                exc_block_size,
                score_decay,
                repr_topk,
                cache_strategy,
                chunk_topk_calc,
                async_global_stream,
                pin_memory,
                faiss,
                perhead,
                dense_decoding=dense_decoding,
            )

        local_q, local_k, local_v = h_q, h_k, h_v
        global_q, global_k, global_v = h_q, h_k, h_v

        o = past_key_value.append(
            local_q,
            local_k,
            local_v,
            global_q,
            global_k,
            global_v,
        )

        o = o.view(batch_size, num_heads, len_q, dim_head).permute(0, 2, 1, 3)
        o = o.reshape(batch_size, len_q, dim_head * num_heads)
        o = attention_out(o)

        if use_cache:
            return o, past_key_value
        else:
            return o

    return forward


class GreedySearch:
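    # Minimal greedy decoder that feeds the prompt to the model in fixed-size
    # chunks so the streaming KV manager never sees an over-long input.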
    def __init__(self, model, tokenizer):
        model.eval()
        self.device = model.device
        self.model = model
        self.tokenizer = tokenizer
        self.past_kv = None

    def clear(self):
        self.past_kv = None

    def _process_texts(self, input_text):
        model_inputs = {}
        input_ids = self.tokenizer.encode(input_text)

        model_inputs["input_ids"] = input_ids
        model_inputs["attention_mask"] = [1] * len(model_inputs["input_ids"])

        for key in model_inputs:
            model_inputs[key] = (
                torch.tensor(model_inputs[key]).int().unsqueeze(0).cuda()
            )

        return model_inputs

    def generate(self, text=None, input_ids=None, **kwargs):
        if input_ids is None:
            model_inputs = self._process_texts(text)
            input_ids = model_inputs["input_ids"]

        with torch.inference_mode():
            result = self._decode(input_ids, **kwargs)

        self.clear()
        return result

    def _decode(
        self,
        input_ids,
        max_length=100,
        extra_end_token_ids=[],
        chunk_size: int = 4096,
        output=False,
    ):
        if input_ids.dim() == 1:
            input_ids = input_ids[None, :]
        input_ids = input_ids.cuda()
        attention_mask = torch.ones_like(input_ids)
        assert input_ids.size(0) == 1
        length = input_ids.size(1)
        end_token_ids = extra_end_token_ids + [self.tokenizer.eos_token_id]
        logits = None
        past_key_values = self.past_kv
        if output:
            output_text = ""

        for i in range(max_length + 1):
            if i == 0:
                if chunk_size is None:
                    chunk_size = input_ids.size(1)
                for st in range(0, input_ids.size(1) - 1, chunk_size):
                    ed = min(input_ids.size(1) - 1, st + chunk_size)
                    out = self.model(
                        input_ids=input_ids[:, st:ed],
                        attention_mask=attention_mask[:, :ed],
                        use_cache=True,
                        return_dict=True,
                        past_key_values=past_key_values,
                    )
                    logits, past_key_values = out.logits, out.past_key_values

                out = self.model(
                    input_ids=input_ids[:, -1:],
                    attention_mask=attention_mask,
                    use_cache=True,
                    return_dict=True,
                    past_key_values=past_key_values,
                )
                logits, past_key_values = out.logits, out.past_key_values
            else:
                out = self.model(
                    input_ids=input_ids[:, -1:],
                    attention_mask=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=True,
                    return_dict=True,
                )
                logits, past_key_values = out.logits, out.past_key_values

            logits = logits[:, -1, :]
            word = logits.argmax(dim=-1)
            if word.item() in end_token_ids or i == max_length:
                break

            input_ids = torch.cat((input_ids, word.view(1, 1)), dim=-1)
            attention_mask = torch.cat(
                (
                    attention_mask,
                    torch.ones(
                        (attention_mask.size(0), 1),
                        dtype=torch.int,
                        device=attention_mask.device,
                    ),
                ),
                dim=-1,
            )
            if output:
                tmp = self.tokenizer.decode(input_ids.squeeze(0)[length:])
                if len(tmp) > len(output_text):
                    sys.stdout.write(tmp[len(output_text) :])
                    sys.stdout.flush()
                    output_text = tmp

        self.past_kv = past_key_values

        if output:
            sys.stdout.write("\n")
            sys.stdout.flush()

        # return [self.tokenizer.decode(input_ids.squeeze(0)[length:])]
        return input_ids


class InfLLMGenerator(GreedySearch):
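    # GreedySearch wrapper with a Hugging Face-like generate()/__call__ interface;
    # __call__ returns prefill logits computed chunk by chunk.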
    def generate(
        self,
        input_ids=None,
        generation_config=None,
        pad_token_id=None,
        max_new_tokens=None,
    ):
        if max_new_tokens is None:
            max_new_tokens = generation_config.max_new_tokens
        return super().generate(
            text=None,
            input_ids=input_ids,
            max_length=max_new_tokens,
            chunk_size=8192,
            extra_end_token_ids=[pad_token_id] if pad_token_id is not None else [],
        )

    @torch.no_grad()
    def __call__(self, input_ids=None, *args, **kwargs):
        # chunked forward
        chunk_size = 8192
        all_logits = torch.empty(0, dtype=torch.bfloat16).to(input_ids.device)
        for st in range(0, input_ids.size(1), chunk_size):
            torch.cuda.empty_cache()
            ed = min(input_ids.size(1), st + chunk_size)
            out = self.model(
                input_ids=input_ids[:, st:ed],
            )
            logits = out.logits.to(torch.bfloat16)
            all_logits = torch.cat((all_logits, logits), dim=1)

        return CausalLMOutput(logits=all_logits)