from collections import deque

import faiss
import torch
import torch.nn.functional as F
import numpy as np
from torch import nn


class KNN:
    """
    KNN memory for one element in a batch. Handles all heads.

    For every attention head it keeps a faiss inner-product index over the
    stored keys plus a deque of the matching value rows. Position ``j`` in a
    head's deque holds the value whose key has id ``j`` in that head's index;
    ``shrink`` and ``search`` both rely on this alignment.
    """

    def __init__(self, num_heads, head_dim, memories_size=16000, shrink_size=None, cache=None):
        self.num_heads = num_heads
        self.head_dim = head_dim
        # Number of most recent entries kept per head after a shrink.
        self.memories_size = memories_size
        # A head is only shrunk once it holds more than this many entries
        # (default: 10% above memories_size), so trimming happens in batches.
        self.shrink_size = shrink_size or memories_size * 1.1
        self.indexes = [faiss.IndexFlat(
            self.head_dim, faiss.METRIC_INNER_PRODUCT) for _ in range(self.num_heads)]
        self.values = [deque([]) for _ in range(self.num_heads)]
        # Caching is not implemented; add() raises if a cache is configured.
        self.cache = cache

    def __del__(self):
        # Guard: __init__ may have failed before `indexes` was assigned.
        if hasattr(self, 'indexes'):
            del self.indexes
            del self.values

    def clear(self):
        """Remove every stored key and value from all heads."""
        for index in self.indexes:
            index.reset()

        for value in self.values:
            value.clear()

    def shrink(self):
        """Shrink every head above ``shrink_size`` back to ``memories_size``, dropping the oldest entries."""

        for i, index in enumerate(self.indexes):
            if index.ntotal <= self.shrink_size:
                continue

            to_delete = index.ntotal - self.memories_size
            # faiss ids are int64; np.arange's default integer dtype is
            # platform-dependent (int32 on Windows with NumPy < 2), so be explicit.
            index.remove_ids(np.arange(to_delete, dtype=np.int64))

            # A flat index renumbers the remaining vectors after removal, so
            # dropping the same number of oldest values keeps ids aligned with
            # deque positions.
            for _ in range(to_delete):
                self.values[i].popleft()

    def add(self, key, value):
        """
        Append keys and values for every head, then shrink if needed.

        ``key[i]`` / ``value[i]`` hold the rows for head ``i`` (presumably
        arrays of shape (n, head_dim) -- TODO confirm against callers).

        Raises:
            RuntimeError: if a cache was configured (not implemented). This is
                checked before anything is stored, so memory stays unchanged.
        """
        if self.cache is not None:
            raise RuntimeError("Cache for KNN not implemented")

        for i, k in enumerate(key):
            self.indexes[i].add(k)
        for i, v in enumerate(value):
            self.values[i].extend(v)

        self.shrink()

    def search(self, query, k=32):
        """
        Search each head's index for ``query[i]`` (the queries of head ``i``).

        Returns ``(keys, values)`` as numpy arrays stacked over heads, holding
        the top-``k`` keys by inner product and their stored values. ``k`` is
        capped at the number of stored entries. When memory is empty both
        arrays have shape ``(query.shape[0], query.shape[1], 0, query.shape[2])``.
        """

        k = min(k, len(self.values[0]))

        if k <= 0:
            # numpy (not torch), so the result type matches the non-empty path.
            empty_shape = (query.shape[0], query.shape[1], 0, query.shape[2])
            return np.empty(empty_shape, dtype=np.float32), \
                np.empty(empty_shape, dtype=np.float32)

        Ks, Vs = [], []

        for i, q in enumerate(query):
            D, I, K = self.indexes[i].search_and_reconstruct(q, k=k)
            # Converts the whole deque to an array, then gathers rows by id.
            V = np.take(self.values[i], indices=I, axis=0)
            Ks.append(K)
            Vs.append(V)

        return np.stack(Ks, axis=0), np.stack(Vs, axis=0)


class KNNLayer:
    """
    KNN Attention layer. Handles KNN's for batch (every element separately).

    With ``share_memory=True`` a single KNN memory serves every batch element;
    otherwise each of the ``batch_size`` elements gets its own memory. While
    the layer is closed, ``add``, ``clear`` and ``clear_batches`` are no-ops;
    ``search`` still works.
    """

    def __init__(self, config, share_memory=True, batch_size=None, memory_size=16000, shrink_size=None, n_jobs=4, cache=None):
        if not share_memory and batch_size is None:
            raise RuntimeError(
                "If share_memory is False, batch_size should be passed")

        # `config` must expose hidden_size and num_attention_heads.
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.embed_dim // self.num_heads

        self.share_memory = share_memory
        self.batch_size = batch_size
        self.memory_size = memory_size
        self.shrink_size = shrink_size or self.memory_size * 1.1
        self.closed = False

        if not share_memory:
            self.knns = [KNN(self.num_heads, self.head_dim, memory_size,
                             self.shrink_size, cache=cache) for _ in range(self.batch_size)]
        else:
            self.knn = KNN(self.num_heads, self.head_dim,
                           memory_size, self.shrink_size, cache=cache)

        # Global faiss threading setting: affects all faiss calls, not just this layer.
        faiss.omp_set_num_threads(n_jobs)

    def _memory_for(self, batch_idx):
        """Return the KNN memory serving batch element ``batch_idx``."""
        return self.knn if self.share_memory else self.knns[batch_idx]

    def clear_batches(self, batch_indexes):
        """Clear the memories of the given batch elements (no-op with shared memory)."""
        if self.closed:
            return

        if not self.share_memory:
            for idx in batch_indexes:
                self.knns[idx].clear()

    def clear(self):
        """Clear every memory held by this layer."""
        if self.closed:
            return

        if self.share_memory:
            self.knn.clear()
        else:
            for knn in self.knns:
                knn.clear()

    def add(self, keys, values):
        """
        Store keys/values; element ``i`` of the batch goes to ``_memory_for(i)``.

        ``keys``/``values`` are torch tensors indexed first by batch element.
        """
        if self.closed:
            return

        keys, values = keys.numpy(force=True), values.numpy(force=True)
        for i, (key, value) in enumerate(zip(keys, values)):
            self._memory_for(i).add(key, value)

    def search(self, queries, k=32):
        """
        Retrieve up to ``k`` neighbours for every batch element.

        Per-element results may hold fewer than ``k`` neighbours, so they are
        zero-padded along axis 2 to a common length.

        Returns:
            ``(keys, values, masks)`` torch tensors. ``masks`` has shape
            ``(batch, max_len)``: 1 for real neighbours, 0 for padding.
        """
        queries = queries.numpy(force=True)
        keys, values = [], []

        for i, query in enumerate(queries):
            key, value = self._memory_for(i).search(query, k)
            keys.append(key)
            values.append(value)

        max_len = max((key.shape[2] for key in keys), default=0)
        masks = np.ones((len(keys), max_len), dtype=np.float32)

        for i, (key, value) in enumerate(zip(keys, values)):
            length = key.shape[2]
            if length == max_len:
                continue

            key_pad = list(key.shape)
            key_pad[2] = max_len - length
            value_pad = list(value.shape)
            value_pad[2] = max_len - length
            keys[i] = np.concatenate(
                (key, np.zeros(key_pad, dtype=np.float32)), axis=2)
            values[i] = np.concatenate(
                (value, np.zeros(value_pad, dtype=np.float32)), axis=2)
            masks[i, length:] = 0

        return torch.from_numpy(np.stack(keys, axis=0)),\
            torch.from_numpy(np.stack(values, axis=0)),\
            torch.from_numpy(masks)

    def close(self):
        """Make add/clear operations no-ops until the layer is reopened."""
        self.closed = True

    def open(self):
        """Re-enable add/clear operations."""
        self.closed = False

    def reset(self):
        """Reopen the layer and clear all memories."""
        self.open()
        self.clear()


class ClearMemoryLayer(nn.Module):
    """
    Wraps ``next_layer`` and resets KNN memory for rows that start a new sequence.

    Before delegating to ``next_layer``, every batch row whose first token is
    ``bos_token`` has its memory cleared via ``knn_memory.clear_batches``.
    ``_clear_if_token`` (clear rows containing a token anywhere) is available,
    but ``forward`` does not call it; ``eos_token`` is stored but currently unused.
    """

    def __init__(self, knn_memory, bos_token, eos_token, next_layer):
        super().__init__()

        self.knn_memory = knn_memory
        self.bos_token = bos_token
        self.eos_token = eos_token
        self.next_layer = next_layer

    def _clear_rows(self, row_mask):
        """Clear memory for every batch row where the boolean ``row_mask`` is True."""
        # nonzero() yields one (row, ...) coordinate per True entry; column 0
        # holds the batch-row indices.
        batches_to_clear = row_mask.nonzero()[:, 0]

        if len(batches_to_clear) > 0:
            self.knn_memory.clear_batches(batches_to_clear)

    def _clear_if_token(self, tokens, token):
        """Clear memory for every batch row in which ``token`` appears (checked along the last dim)."""
        # Previously indexed nonzero()[0], which only cleared the first matching row.
        self._clear_rows((tokens == token).any(dim=-1))

    def forward(self, tokens, *args, **kwargs):
        """Clear memory of rows starting with ``bos_token``, then run ``next_layer``."""
        self._clear_rows(tokens[:, 0] == self.bos_token)

        return self.next_layer(tokens, *args, **kwargs)