import torch
import torch.nn.functional as F


def apply_temperature(scores, temperature):
    # Scale logits by the sampling temperature: values < 1 sharpen the
    # distribution, values > 1 flatten it; a non-positive value is a no-op.
    if temperature > 0:
        scores = scores / temperature
    return scores


def apply_top_p(scores, top_p, filter_value=-float("Inf"), min_tokens_to_keep=1):
    # Nucleus (top-p) filtering: keep the smallest set of tokens whose cumulative
    # probability exceeds top_p and mask out the rest.
    if 0 < top_p < 1:
        sorted_logits, sorted_indices = torch.sort(scores, descending=False)
        cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)

        # With an ascending sort, tokens whose cumulative probability stays at or
        # below 1 - top_p form the low-probability tail and are removed.
        sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep
            sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(
            1, sorted_indices, sorted_indices_to_remove
        )
        scores = scores.masked_fill(indices_to_remove, filter_value)
    return scores


def apply_top_k(logits, top_k):
    # Top-k filtering: keep only the top_k highest-scoring tokens.
    top_k = min(top_k, logits.size(-1))  # Safety check against small vocabularies
    if top_k > 0:
        # Mask out every token whose logit is below the k-th largest logit.
        # Note that this modifies `logits` in place.
        indices_to_remove = logits < torch.topk(logits.float(), top_k)[0][..., -1, None]
        logits[indices_to_remove] = -float("Inf")

    return logits
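

# A minimal smoke test for the three logit warpers above; the toy logits and the
# sampling settings below are illustrative assumptions, not project defaults.
def _demo_logit_warpers():
    logits = torch.tensor([[2.0, 1.0, 0.5, 0.1, -1.0]])  # one row over a 5-token vocab
    filtered = apply_top_k(logits.clone(), top_k=3)       # keep the 3 largest logits (in place on the clone)
    filtered = apply_top_p(filtered, top_p=0.9)           # drop the low-probability tail
    filtered = apply_temperature(filtered, 0.7)           # sharpen what remains
    print(F.softmax(filtered, dim=-1))                    # masked tokens end up with probability 0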


def apply_advanced_repetition_penalty(
    input_ids, scores, penalty_range, penalty_slope, penalty
):
    # Penalize tokens that already appear in the (optionally windowed) context so
    # they are less likely to be sampled again.
    penalty_range = int(penalty_range)
    clipped_penalty_range = min(input_ids.shape[-1], penalty_range)

    if penalty != 1.0:
        if penalty_range > 0:
            # Only the most recent clipped_penalty_range context tokens are penalized.
            if clipped_penalty_range < input_ids.shape[1]:
                input_ids = input_ids[..., -clipped_penalty_range:]

            if penalty_slope != 0:
                # Build a sigmoid-like ramp over the window so recent tokens receive
                # (nearly) the full penalty and older tokens are penalized less; the
                # scalar penalty is replaced by a per-position tensor.
                _penalty = (
                    torch.arange(
                        penalty_range, dtype=scores.dtype, device=scores.device
                    )
                    / (penalty_range - 1)
                ) * 2.0 - 1
                _penalty = (penalty_slope * _penalty) / (
                    1 + torch.abs(_penalty) * (penalty_slope - 1)
                )
                _penalty = 1 + ((_penalty + 1) / 2).unsqueeze(0) * (penalty - 1)
                penalty = _penalty[..., -clipped_penalty_range:]

        # Shrink positive scores and push negative scores further down for repeated tokens.
        score = torch.gather(scores, 1, input_ids)
        score = torch.where(score <= 0, score * penalty, score / penalty)
        scores.scatter_(1, input_ids, score)

    return scores
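

# A minimal sketch of how the repetition penalty reshapes next-token scores for
# tokens already present in the context; the toy ids and settings below are
# illustrative assumptions, not values used by the project.
def _demo_repetition_penalty():
    input_ids = torch.tensor([[3, 1, 3]])                # token 3 was generated twice
    scores = torch.tensor([[0.2, -0.5, 1.0, 2.0, 0.7]])  # next-token scores over a 5-token vocab
    penalized = apply_advanced_repetition_penalty(
        input_ids, scores.clone(), penalty_range=0, penalty_slope=0.0, penalty=1.2
    )
    print(penalized)  # scores for tokens 1 and 3 move toward / further below zero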


class LmGeneration:
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def generate(self, args, prompts, cut_off=None, cut_off_times=1):
        # cut_off is an optional stop string; each prompt may hit it cut_off_times
        # times before its generation is treated as finished.
        if cut_off is not None:
            cut_off_times = [cut_off_times] * len(prompts)
        batch = len(prompts)
        assert batch <= args.batch_size

        prompt_tokens = [self.tokenizer.encode(x, bos=True, eos=False) for x in prompts]

        min_prompt_len = min([len(x) for x in prompt_tokens])
        # max_prompt_len = max([len(x) for x in prompt_tokens])

        total_len = args.seq_length

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Pre-fill the whole buffer with pad tokens, then copy each prompt in.
        tokens = torch.full(
            (batch, total_len), self.tokenizer.pad_token, dtype=torch.long, device=device
        )
        for idx, t in enumerate(prompt_tokens):
            tokens[idx, : len(t)] = torch.tensor(t).long()
        # mask marks positions that hold prompt tokens and must not be overwritten.
        mask = tokens != self.tokenizer.pad_token
        start_pos = min_prompt_len
        prev_pos = 0
        # Indices of the sequences that are still generating.
        continue_examples = list(range(batch))
        with torch.no_grad():
            for cur_pos in range(start_pos, total_len):
                # Incremental decoding: feed only the tokens produced since the previous
                # step; the model is expected to maintain its own cache per example.
                logits = self.model.forward(
                    tokens[continue_examples, prev_pos:cur_pos], prev_pos, continue_examples
                ).float()
                next_token_scores = apply_top_k(logits, top_k=args.top_k)
                next_token_scores = apply_top_p(next_token_scores, args.top_p)
                next_token_scores = apply_temperature(next_token_scores, args.temperature)
                next_token_scores = apply_advanced_repetition_penalty(
                    tokens[continue_examples, :cur_pos],
                    next_token_scores,
                    args.repetition_penalty_range,
                    args.repetition_penalty_slope,
                    args.repetition_penalty
                )
                scores = F.softmax(next_token_scores, dim=-1)
                next_token = torch.multinomial(scores, num_samples=1).reshape(-1)
                # Keep the original token wherever the prompt already covers this position.
                next_token = torch.where(
                    mask[continue_examples, cur_pos], tokens[continue_examples, cur_pos], next_token
                )
                tokens[continue_examples, cur_pos] = next_token
                prev_pos = cur_pos
                # Drop finished examples: a sequence keeps generating only if it contains
                # no eos token and has not yet hit the cut_off string enough times.
                continue_examples = []
                for i, t in enumerate(tokens.tolist()):
                    try:
                        t.index(self.tokenizer.eos_token)
                    except ValueError:
                        if cut_off is not None:
                            if self.tokenizer.decode(t[:cur_pos + 1]).endswith(cut_off):
                                if cut_off_times[i] == 1:
                                    continue
                                else:
                                    cut_off_times[i] -= 1
                        continue_examples.append(i)
                if len(continue_examples) == 0:
                    break

        decoded = []
        for t in tokens.tolist():
            t = t[: args.seq_length]
            # Truncate at the first pad token and at the first eos token; either
            # marker may be absent, so the two lookups are handled independently.
            try:
                t = t[: t.index(self.tokenizer.pad_token)]
            except ValueError:
                pass
            try:
                t = t[: t.index(self.tokenizer.eos_token)]
            except ValueError:
                pass
            decoded.append(self.tokenizer.decode(t))

        return decoded
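

# A minimal usage sketch, assuming a model whose forward(tokens, prev_pos, batch_indices)
# performs incremental decoding and a tokenizer exposing integer pad_token / eos_token ids
# plus encode / decode; the argument names and values below are illustrative assumptions.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        batch_size=2,
        seq_length=256,
        top_k=40,
        top_p=0.9,
        temperature=0.8,
        repetition_penalty=1.1,
        repetition_penalty_range=512,
        repetition_penalty_slope=0.0,
    )
    # model and tokenizer are stand-ins; plug in the project's own loading code here.
    # generator = LmGeneration(model, tokenizer)
    # print(generator.generate(args, ["Once upon a time"], cut_off="\n")[0])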