import copy
import os
from datetime import timedelta
import sys
from time import time
from pathlib import Path
from typing import List, Literal, Optional, Tuple, Union

import torch
import torch.nn.functional as F
import transformers
from accelerate import (
    Accelerator,
    DistributedType,
    InitProcessGroupKwargs,
    find_executable_batch_size,
)
from packaging import version
from peft import PeftModel
from peft import __version__ as PEFT_VERSION
from tqdm import tqdm
from transformers.models.auto.modeling_auto import (
    MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
    MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
)
from transformers import TextStreamer
from transformers.models.dbrx.modeling_dbrx import DbrxExpertGLU
from lm_eval import utils
from lm_eval.api.instance import Instance
from lm_eval.api.model import TemplateLM
from lm_eval.api.registry import register_model
from lm_eval.models.utils import (
    Collator,
    clear_torch_cache,
    get_dtype,
    pad_and_concat,
    stop_sequences_criteria,
)
from lm_eval.models.huggingface import HFLM
from src.utils import get_gpu_details, get_peak_bw, transfer_precision2bytes, get_peak_flops
from src.submission.check_validity import get_model_size
from src.envs import API


class StopWatch(TextStreamer):
    """Streamer that records timings instead of printing generated text.

    It is passed as ``streamer=`` to ``model.generate`` and relies only on the
    order in which ``put``/``end`` are invoked:

    * 1st ``put``  -> marks the start of prefilling,
    * 2nd ``put``  -> closes the prefilling interval and starts the decoding
      interval,
    * every ``put`` after the first increments ``decoding_iterations``,
    * ``end``      -> closes the decoding interval (once).

    The ``value`` handed to ``put`` is never inspected.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Number of ``put`` calls received after the very first one.
        self.decoding_iterations = 0
        # Wall-clock stamps / durations in seconds; None until measured.
        self.start_prefilling = None
        self.prefilling_time = None
        self.start_decoding = None
        self.decoding_time = None

    def put(self, value):
        """Advance the timing state machine by one streamer callback."""
        # First callback: only remember when prefilling began.
        if self.start_prefilling is None:
            self.start_prefilling = time()
            return

        # Second callback: prefilling is over, decoding starts now.
        if self.prefilling_time is None:
            self.prefilling_time = time() - self.start_prefilling
            self.start_decoding = time()

        self.decoding_iterations += 1

    def end(self):
        """Record the decoding duration the first time generation finishes."""
        # Nothing to close if decoding never started or was already closed.
        if self.start_decoding is None or self.decoding_time is not None:
            return
        self.decoding_time = time() - self.start_decoding


class HFLMWithMeasurement(HFLM):
    def __init__(self, **kwargs):
        """Build the underlying ``HFLM`` and remember a few load-time settings.

        All keyword arguments are forwarded unchanged to ``HFLM.__init__``.
        The ``pretrained``, ``revision`` and ``dtype`` entries (``None`` when
        absent) are additionally kept on the instance; ``dtype`` is stored as
        ``self.precision`` and later used for the byte-width and peak-FLOPs
        lookups in ``_model_generate``.
        """
        super().__init__(**kwargs)
        self.pretrained = kwargs.get("pretrained")
        self.revision = kwargs.get("revision")
        self.precision = kwargs.get("dtype")
        # Filled in lazily by ``_detect_num_gpus_used``.
        self.num_gpus = None

    def _detect_num_gpus_used(self):
        """Return the number of distinct CUDA devices holding model parameters.

        The count is computed once from ``self.model.parameters()`` and cached
        in ``self.num_gpus``; later calls return the cached value.
        """
        if self.num_gpus is None:
            cuda_indices = {
                param.device.index
                for param in self.model.parameters()
                if param.device.type == "cuda"
            }
            self.num_gpus = len(cuda_indices)
        return self.num_gpus

    def _loglikelihood_tokens(
        self,
        requests: List[Tuple[Tuple[str, str], List[int], List[int]]],
        disable_tqdm: bool = False,
        override_bs: Optional[int] = None,
    ) -> List[Tuple[Tuple[float, bool], float, int, int, int, int]]:
        """Score continuations and record the per-sample forward-pass time.

        Args:
            requests: Items shaped ``((context_str, continuation_str),
                context_token_ids, continuation_token_ids)``.
            disable_tqdm: When true, the progress bar is disabled (it is also
                disabled on every rank other than 0).
            override_bs: When ``self.batch_size == "auto"``, a non-None value
                is used as the fixed batch size and the batch scheduler is
                not used.

        Returns:
            The collected results passed through ``re_ord.get_original``
            (presumably restoring the caller's request order -- confirm
            against ``Collator``). Each entry is
            ``((logprob_sum, is_greedy), per_sample_time, 0, 0, 0, 0)``:
            ``logprob_sum`` is the summed log-probability of the continuation
            tokens, ``is_greedy`` says whether the per-position argmax equals
            the continuation, and ``per_sample_time`` is the wall time of the
            batch's ``_model_call`` divided by the batch size. The four
            trailing zeros are placeholders; presumably they keep the tuple
            the same length as the ``generate_until`` results -- confirm with
            the consumer.
        """
        # TODO: implement some kind of efficient-request-middleware that lumps together requests with the same context
        res = []

        def _collate(req: Tuple[Tuple[str, str], List[int], List[int]]):
            """Defines the key for the sorted method"""
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end

            toks = req[1] + req[2]
            return -len(toks), tuple(toks)

        def _lookup_one_token_cont(req: Tuple[Tuple[str, str], List[int], List[int]]):
            """Defines the key to group and lookup one-token continuations"""
            # Use with group_by="contexts" (optional)
            # allows for the creation of a lookup, so we can reuse logits in case of one-token continuations.
            # speeds up some multiple-choice tasks proportionally to the number of choices.
            # groups requests by context+continuation[:-1] and infer on one request/group.
            return req[-2] + req[-1][:-1]

        # Grouping by context is only enabled for causal LMs with the logits cache turned on.
        re_ord = Collator(
            requests,
            sort_fn=_collate,
            group_by="contexts"
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
            and self.logits_cache
            else None,
            group_fn=_lookup_one_token_cont,
        )

        # automatic (variable) batch size detection for vectorization
        # pull longest context sample from request
        n_reordered_requests = len(re_ord)
        # Fixed batch size: the configured one, or `override_bs` in "auto" mode, or 0
        # when "auto" mode has no override (the scheduler below is used instead).
        batch_size = (
            self.batch_size
            if self.batch_size != "auto"
            else override_bs
            if override_bs is not None
            else 0
        )
        batch_fn = (
            self._batch_scheduler
            if self.batch_size == "auto"
            and n_reordered_requests > 0
            and not override_bs
            else None
        )

        chunks = re_ord.get_batched(n=batch_size, batch_fn=batch_fn)
        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running loglikelihood requests",
        )
        for chunk in chunks:
            inps = []
            cont_toks_list = []
            inplens = []

            conts = []
            encoder_attns = []

            padding_len_inp = None
            padding_len_cont = None
            # because vectorizing is annoying, we first convert each (context, continuation) pair to padded
            # tensors, then we pack them together into a batch, call the model, and then pick it all apart
            # again because vectorizing is annoying

            for _, context_enc, continuation_enc in chunk:
                # sanity check
                assert len(context_enc) > 0
                assert len(continuation_enc) > 0
                assert len(continuation_enc) <= self.max_length

                # how this all works (illustrated on a causal decoder-only setup):
                #          CTX      CONT
                # inp    0 1 2 3|4 5 6 7 8 9   <- last token is deleted by inp[:, :-1]
                # model  \               \
                # logits   1 2 3|4 5 6 7 8 9   <- the ctx half gets tossed out by the
                # cont_toks      4 5 6 7 8 9      [:, -len(continuation_enc):, :self.vocab_size] slice

                # when too long to fit in context, truncate from the left
                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                    inp = torch.tensor(
                        (context_enc + continuation_enc)[-(self.max_length + 1) :][:-1],
                        dtype=torch.long,
                        device=self.device,
                    )
                    (inplen,) = inp.shape
                elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                    inp = torch.tensor(
                        (context_enc)[-self.max_length :],
                        dtype=torch.long,
                        device=self.device,
                    )
                    (inplen,) = inp.shape

                    # build encoder attn masks
                    encoder_attns.append(torch.ones_like(inp))

                    cont = torch.tensor(
                        (continuation_enc)[-self.max_length :],
                        # TODO: left-shift these?
                        # TODO: our code assumes we never end up truncating conts for either model type
                        dtype=torch.long,
                        device=self.device,
                    )
                    (contlen,) = cont.shape

                    conts.append(cont)

                    # running maximum of continuation lengths in this chunk
                    padding_len_cont = (
                        max(padding_len_cont, contlen)
                        if padding_len_cont is not None
                        else contlen
                    )

                # running maximum of input lengths in this chunk
                padding_len_inp = (
                    max(padding_len_inp, inplen)
                    if padding_len_inp is not None
                    else inplen
                )

                inps.append(inp)  # [1, inp_length]
                cont_toks_list.append(continuation_enc)
                inplens.append(inplen)

            # create encoder attn mask and batched conts, if seq2seq
            call_kwargs = {}
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                batched_inps = pad_and_concat(
                    padding_len_inp, inps, padding_side="right"
                )  # [batch, padding_len_inp]
            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                # TODO: left-pad encoder inps and mask?
                batched_inps = pad_and_concat(
                    padding_len_inp, inps
                )  # [batch, padding_len_inp]
                batched_conts = pad_and_concat(
                    padding_len_cont, conts
                )  # [batch, padding_len_cont]
                batched_encoder_mask = pad_and_concat(
                    padding_len_inp, encoder_attns
                )  # [batch, padding_len_inp]
                call_kwargs = {
                    "attn_mask": batched_encoder_mask,
                    "labels": batched_conts,
                }

            # Time only the model call; the log_softmax below is not included.
            # NOTE(review): there is no explicit device synchronization here, so on
            # CUDA this may under-report if `_model_call` returns before the kernels
            # have finished -- confirm.
            start = time()
            intermediate_res = self._model_call(batched_inps, **call_kwargs)
            end = time()
            multi_logits = F.log_softmax(
                intermediate_res , dim=-1
            )  # [batch, padding_length (inp or cont), vocab]
            # the batch wall time is split evenly across the samples of the batch
            per_sample_time = (end - start) / len(multi_logits)

            for (request_str, ctx_tokens, _), logits, inplen, cont_toks in zip(
                chunk, multi_logits, inplens, cont_toks_list
            ):
                # Slice to original seq length
                contlen = len(cont_toks)
                # take only logits in the continuation
                # (discard context toks if decoder-only ; discard right-padding)
                # also discards + checks for "virtual tokens" in the causal LM's input window
                # from prompt/prefix tuning tokens, if applicable
                ctx_len = (
                    inplen + (logits.shape[0] - padding_len_inp)
                    if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
                    else None
                )
                logits = self._select_cont_toks(logits, contlen=contlen, inplen=ctx_len)
                logits = logits.unsqueeze(0)  # [1, seq, vocab]

                # Check if per-token argmax is exactly equal to continuation
                greedy_tokens = logits.argmax(dim=-1)

                # check for one-token continuation cache hits.
                # noop in case group_by != "contexts" or no cache hit and returns the
                # original args. Otherwise, expands the logits batch dimension and yields each
                # batch along with matching continuation tokens and prompt strings.
                # logits -> [1, seq, vocab]
                # NOTE: the loop variables below deliberately rebind request_str,
                # cont_toks and logits from the enclosing loop.
                for request_str, cont_toks, logits in re_ord.get_cache(
                    req_str=request_str,
                    cxt_toks=ctx_tokens,
                    cont_toks=cont_toks,
                    logits=logits,
                ):
                    cont_toks = torch.tensor(
                        cont_toks, dtype=torch.long, device=self.device
                    ).unsqueeze(0)  # [1, seq]
                    max_equal = (greedy_tokens == cont_toks).all()

                    # Obtain log-probs at the corresponding continuation token indices
                    # last_token_slice = logits[:, -1, :].squeeze(0).tolist()
                    logits = torch.gather(logits, 2, cont_toks.unsqueeze(-1)).squeeze(
                        -1
                    )  # [1, seq]

                    # Answer: (log prob, is-exact-match)
                    answer = (float(logits.sum()), bool(max_equal))

                    # Only the forward-pass time is measured here; the remaining
                    # four slots are fixed zero placeholders.
                    res.append((answer, per_sample_time, 0, 0, 0, 0))

                    self.cache_hook.add_partial("loglikelihood", request_str, answer)
                    pbar.update(1)

        pbar.close()

        return re_ord.get_original(res)

    def _model_generate(self, context, max_tokens, stop, **generation_kwargs):
        """Generate continuations for ``context`` and estimate hardware utilization.

        Args:
            context: Batch of input token ids; ``context.shape`` is read as
                ``(batch_size, context_length)``.
            max_tokens: Forwarded to ``generate`` as ``max_new_tokens``.
            stop: Stop sequences used to build the stopping criteria.
            **generation_kwargs: Extra arguments forwarded to
                ``self.model.generate``. ``temperature`` defaults to 0.0; a
                temperature of 0.0 with no explicit ``do_sample`` selects
                greedy decoding, and the temperature key is then dropped.

        Returns:
            ``(res, end_to_end_time, prefilling_time, token_per_sec, mfu, mbu)``
            where ``res`` is the raw ``generate`` output, the two times are
            wall-clock seconds divided by the batch size, and ``mfu``/``mbu``
            are the estimated FLOPs and memory-bandwidth utilization ratios.
            If the stopwatch could not measure a positive decoding time,
            ``token_per_sec`` (and therefore ``mfu``/``mbu``) is reported as 0
            instead of raising.

        Raises:
            ValueError: If no precision was configured and none can be derived
                from the model's quantization config, or if the FFN dimension
                cannot be found in the model config.
        """
        # temperature = 0.0 if not set
        # if do_sample is false and temp==0.0:
        # remove temperature, as do_sample=False takes care of this
        # and we don't want a warning from HF
        generation_kwargs["temperature"] = generation_kwargs.get("temperature", 0.0)
        do_sample = generation_kwargs.get("do_sample", None)

        # The temperature has to be a strictly positive float -- if it is 0.0, use greedy decoding strategies
        if generation_kwargs.get("temperature") == 0.0 and do_sample is None:
            generation_kwargs["do_sample"] = do_sample = False

        if do_sample is False and generation_kwargs.get("temperature") == 0.0:
            generation_kwargs.pop("temperature")

        context_length = context.shape[1]

        # Wrapper classes named "MoE" keep the real model (and its config) one level down.
        if self.model.__class__.__name__ == "MoE":
            model_config = self.model.model.config
        else:
            model_config = self.model.config

        if not self.precision:
            # No dtype was given: fall back to the quantization flags. Models without
            # a quantization config (or with one lacking these flags) end up in the
            # explicit "Unknown precision" error instead of an AttributeError.
            quant_config = getattr(model_config, "quantization_config", None)
            if getattr(quant_config, "_load_in_4bit", False):
                self.precision = "4bit"
            elif getattr(quant_config, "_load_in_8bit", False):
                self.precision = "8bit"
            else:
                raise ValueError("Unknown precision")

        # Count the linear projections of one expert FFN by inspecting the first
        # transformer block only; attention modules and non-expert modules are skipped.
        linear_count = 0
        element_wise_mul = 0
        for name, module in self.model.named_modules():
            in_first_block = "layers.0." in name or "transformer.blocks.0" in name
            if not in_first_block or "attn" in name:
                continue
            if "experts.0." not in name and "ffn.experts" not in name:
                continue
            if "linear_v" in name:
                element_wise_mul = 1
            if isinstance(module, torch.nn.Linear):
                linear_count += 1
            elif isinstance(module, DbrxExpertGLU):
                # DbrxExpertGLU is counted as three projections plus a gating multiply.
                linear_count = 3
                element_wise_mul = 1
        print(f"linear_count: {linear_count}")
        print(f"element_wise_mul: {element_wise_mul}")
        print(f"GPU usage: {self._detect_num_gpus_used()}")

        stopping_criteria = stop_sequences_criteria(
            self.tokenizer, stop, context.shape[1], context.shape[0]
        )
        # The streamer is used purely as a timer for the prefill / decode phases.
        stop_watch = StopWatch(self.tokenizer)
        start = time()
        res = self.model.generate(
            input_ids=context,
            max_new_tokens=max_tokens,
            stopping_criteria=stopping_criteria,
            pad_token_id=self.tokenizer.pad_token_id,
            use_cache=True,
            streamer=stop_watch,
            **generation_kwargs,
        )
        end = time()

        batch_size = context.shape[0]
        output_length = stop_watch.decoding_iterations

        precision_bytes = transfer_precision2bytes(self.precision)

        model_size_param = sum(p.numel() for p in self.model.parameters())

        # The config attribute names differ between architectures; try the known spellings.
        if hasattr(model_config, "num_hidden_layers"):
            n_layers = model_config.num_hidden_layers
        elif hasattr(model_config, "num_layers"):
            n_layers = model_config.num_layers
        else:
            n_layers = model_config.n_layers

        if hasattr(model_config, "hidden_size"):
            d_model = model_config.hidden_size
        else:
            d_model = model_config.d_model

        if hasattr(model_config, "num_experts_per_tok"):
            n_experts_per_tok = model_config.num_experts_per_tok
        elif hasattr(model_config, "num_selected_experts"):
            n_experts_per_tok = model_config.num_selected_experts
        elif hasattr(model_config, "ffn_config"):
            n_experts_per_tok = model_config.ffn_config.moe_top_k
        else:
            n_experts_per_tok = 1

        if hasattr(model_config, "ffn_dim"):
            d_ff = model_config.ffn_dim
        elif hasattr(model_config, "intermediate_size"):
            d_ff = model_config.intermediate_size
        elif hasattr(model_config, "d_ff"):
            d_ff = model_config.d_ff
        elif hasattr(model_config, "ff_ratio"):
            d_ff = d_model * model_config.ff_ratio
        elif hasattr(model_config, "ffn_config"):
            d_ff = model_config.ffn_config.ffn_hidden_size
        else:
            raise ValueError("Unknown FFN dimension")

        if hasattr(model_config, "num_local_experts"):
            num_experts = model_config.num_local_experts
        elif hasattr(model_config, "num_experts"):
            num_experts = model_config.num_experts
        elif hasattr(model_config, "ffn_config"):
            num_experts = model_config.ffn_config.moe_num_experts
        else:
            num_experts = 1

        # Parameters of ONE expert FFN summed over all layers.
        ffn_params = n_layers * d_ff * linear_count * d_model

        # Everything that is not an expert FFN is used by every token.
        shared_params = model_size_param - num_experts * ffn_params

        # Parameters actually touched per token: shared part + the routed experts.
        model_size = shared_params + n_experts_per_tok * ffn_params

        # Keys and values for every layer, per cached token.
        per_token_kv_size = 2 * n_layers * d_model * precision_bytes

        # Query the hardware description once and scale the single-GPU peaks by the
        # number of GPUs the model is spread over.
        num_gpus = self._detect_num_gpus_used()
        gpu_details = get_gpu_details()
        peak_bw = get_peak_bw(gpu_details) * num_gpus
        peak_flops = get_peak_flops(gpu_details, self.precision) * num_gpus

        # Full prompt cache plus the average share of the generated tokens, converted
        # from bytes to 1e9-byte units.
        kv_size = context_length * per_token_kv_size + (output_length - 1) * per_token_kv_size / 2
        kv_size = kv_size / 1e9

        n_vocab = model_config.vocab_size

        end_to_end_time = (end - start) / batch_size
        # The stopwatch leaves its timings at None when generation ended before the
        # corresponding phase was observed; report 0 instead of failing on None.
        prefilling_time = (stop_watch.prefilling_time or 0.0) / batch_size
        decoding_time = (stop_watch.decoding_time or 0.0) / batch_size
        # A zero decoding time would otherwise abort the whole evaluation run with
        # a ZeroDivisionError.
        token_per_sec = output_length / decoding_time if decoding_time > 0 else 0.0
        # 1e9-byte units moved per second; presumably the same unit as peak_bw -- confirm.
        achieve_mem_bw = (model_size * precision_bytes / 1e9 + kv_size) * token_per_sec

        avg_context_length = context_length + (output_length - 1) / 2
        flops_per_token = 2 * model_size + ((linear_count + element_wise_mul) * n_layers * avg_context_length * d_model) + 4 * d_model + 2 * d_model * n_vocab

        # TODO: only llama-type decoder-only models and MoE models of the Switch Transformer / Mixtral kind are supported
        mfu = token_per_sec * flops_per_token / peak_flops
        mbu = achieve_mem_bw / peak_bw

        print(f"mfu: {mfu}, mbu: {mbu}")

        return res, end_to_end_time, prefilling_time, token_per_sec, mfu, mbu

    def generate_until(
        self, requests: List[Instance], disable_tqdm: bool = False
    ) -> List[Tuple[str, float, float, float, float, float]]:
        """Run batched generation for ``requests`` and attach performance metrics.

        Args:
            requests: Instances whose ``args`` are ``(context, gen_kwargs)``
                pairs. ``gen_kwargs`` must be a dict; its optional ``until``
                entry may be a single string or a list of stop strings, and
                its optional ``max_gen_toks`` overrides ``self.max_gen_toks``.
            disable_tqdm: When true, the progress bar is disabled (it is also
                disabled on every rank other than 0).

        Returns:
            One tuple per request, passed through ``re_ords.get_original``:
            ``(text, end_to_end_time, prefilling_time, token_per_sec, mfu,
            mbu)``. ``text`` is the decoded continuation cut at the first
            stop sequence; the metrics are those returned by
            ``_model_generate`` for the batch the request was part of.

        Raises:
            ValueError: If ``gen_kwargs`` is not a dict or ``until`` is
                neither a string nor a list.
        """
        res = []

        def _collate(req: Tuple[str, dict]):
            """Defines the key for the sorted method"""
            # the negative sign on len(toks) sorts descending - this has a few advantages:
            # - time estimates will always be over not underestimates, which is more useful for planning
            # - to know the size of a batch when going through the list, you know the first one is always the batch
            #   padded context length. this is useful to simplify the batching logic and more importantly to make
            #   automatic adaptive batches much much easier to implement
            # - any OOMs will happen right away rather than near the end
            toks = self.tok_encode(req[0])
            return -len(toks), req[0]

        pbar = tqdm(
            total=len(requests),
            disable=(disable_tqdm or (self.rank != 0)),
            desc="Running generate_until requests",
        )
        adaptive_batch_size = None
        if self.batch_size == "auto":
            # using rolling window with maximum context
            print("Passed argument batch_size = auto. Detecting largest batch size")
            batch_size = self._detect_batch_size()
            print(f"Determined Largest batch size: {batch_size}")
            adaptive_batch_size = batch_size
        # for each different set of kwargs, we execute all requests, by batch.
        batch_size = (
            self.batch_size
            if self.batch_size != "auto"
            else adaptive_batch_size
            if adaptive_batch_size is not None
            else 0
        )
        batch_fn = (
            self._batch_scheduler
            if self.batch_size == "auto" and not adaptive_batch_size
            else None
        )

        # we group requests by their generation_kwargs,
        # so that we don't try to execute e.g. greedy sampling and temp=0.8 sampling
        # in the same batch.
        # group_fn=lambda x: x[1] -> x=(context, gen_kwargs)
        re_ords = Collator(
            [reg.args for reg in requests],
            sort_fn=_collate,
            group_by="gen_kwargs",
            group_fn=lambda x: x[1],
        )
        chunks = re_ords.get_batched(n=batch_size, batch_fn=batch_fn)
        for chunk in chunks:
            contexts, all_gen_kwargs = zip(*chunk)
            # we assume all gen kwargs in the batch are the same
            # this is safe to assume because the `grouper` object ensures it.
            gen_kwargs = all_gen_kwargs[0]
            # unpack our keyword arguments.
            until = None
            if isinstance(gen_kwargs, dict):
                kwargs = copy.deepcopy(gen_kwargs)  # edge case for repeats > 1
                if "until" in kwargs:
                    until = kwargs.pop("until")
                    if isinstance(until, str):
                        # A single stop string becomes a one-element list. (This used
                        # to wrap the whole kwargs dict, which broke the stop criteria
                        # and the post-hoc split below.)
                        until = [until]
                    elif not isinstance(until, list):
                        raise ValueError(
                            f"Expected `kwargs['until']` to be of type Union[str,list] but got {until}"
                        )
            else:
                raise ValueError(
                    f"Expected `kwargs` to be of type `dict` but got {type(gen_kwargs)}"
                )
            # add the end-of-turn marker to the stop sequences; it is a fixed string
            # here rather than being derived from the tokenizer
            eos = "<|eot_id|>"
            if not until:
                until = [eos]
            else:
                until.append(eos)

            max_gen_toks = kwargs.pop("max_gen_toks", self.max_gen_toks)

            # set the max length in tokens of inputs ("context_enc")
            if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                # max len for inputs = max length, minus room to generate the max new tokens
                max_ctx_len = self.max_length - max_gen_toks
            elif self.AUTO_MODEL_CLASS == transformers.AutoModelForSeq2SeqLM:
                # max len for inputs = encoder's whole max_length
                max_ctx_len = self.max_length

            # encode, pad, and truncate contexts for this batch
            context_enc, attn_masks = self.tok_batch_encode(
                contexts,
                left_truncate_len=max_ctx_len,
                truncation=self.truncation,
            )

            context_enc = context_enc.to(self.device)
            attn_masks = attn_masks.to(self.device)

            if "max_tokens" not in kwargs:
                kwargs["max_tokens"] = max_gen_toks

            # perform batched generation; the metrics describe the whole batch and are
            # attached unchanged to every sample of it
            cont, end_to_end_time, prefilling_time, token_per_sec, mfu, mbu = self._model_generate(
                context=context_enc,
                attention_mask=attn_masks,
                stop=until,
                **kwargs,
            )

            cont_toks_list = cont.tolist()
            for cont_toks, context in zip(cont_toks_list, contexts):
                # discard context + left-padding toks if using causal decoder-only LM
                if self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM:
                    cont_toks = cont_toks[context_enc.shape[1] :]

                s = self.tok_decode(cont_toks)

                # use secondary stop seqs to cut off should-have-been-stopped content post-hoc
                for term in until:
                    if len(term) > 0:
                        # ignore '' separator,
                        # for seq2seq case where self.tok_decode(self.eot_token_id) = ''
                        s = s.split(term)[0]

                res.append((s, end_to_end_time, prefilling_time, token_per_sec, mfu, mbu))

                self.cache_hook.add_partial("generate_until", (context, gen_kwargs), s)
                pbar.update(1)
        # reorder this group of results back to original unsorted form
        res = re_ords.get_original(res)

        pbar.close()

        return res