from pathlib import Path

import tensorrt_llm
import torch
from tensorrt_llm.runtime import ModelRunner, ModelRunnerCpp

from modules import shared
from modules.logging_colors import logger
from modules.text_generation import (
    get_max_prompt_length,
    get_reply_from_output_ids
)


class TensorRTLLMModel:
    """Thin wrapper around a TensorRT-LLM runner used for text generation.

    Instances are expected to be built with :meth:`from_pretrained`, which
    attaches two attributes to the returned object:

    * ``model`` -- the runner returned by ``ModelRunner.from_dir`` or
      ``ModelRunnerCpp.from_dir`` (selected by ``shared.args.cpp_runner``).
    * ``runtime_rank`` -- the value of ``tensorrt_llm.mpi_rank()`` at load time.
    """

    # Output-token budget used for the "ModelRunnerCpp" backend: it is passed
    # as ``max_output_len`` at load time, subtracted from ``max_seq_len`` to get
    # ``max_input_len``, and caps ``max_new_tokens`` at generation time.
    CPP_RUNNER_MAX_OUTPUT_LEN = 512

    def __init__(self):
        pass

    @classmethod
    def from_pretrained(cls, path_to_model):
        """Load the engine found under ``shared.args.model_dir / path_to_model``.

        Args:
            path_to_model: Engine directory, relative to ``shared.args.model_dir``.

        Returns:
            A new instance with ``model`` and ``runtime_rank`` set.
        """
        engine_dir = Path(f'{shared.args.model_dir}') / Path(path_to_model)
        runtime_rank = tensorrt_llm.mpi_rank()

        # Settings shared by both runner implementations.
        runner_kwargs = dict(
            engine_dir=str(engine_dir),
            lora_dir=None,
            rank=runtime_rank,
            debug_mode=False,
            lora_ckpt_source="hf",
        )

        if shared.args.cpp_runner:
            logger.info("TensorRT-LLM: Using \"ModelRunnerCpp\"")
            runner_cls = ModelRunnerCpp
            # Extra limits that are only passed to the C++ runner.
            runner_kwargs.update(
                max_batch_size=1,
                max_input_len=shared.args.max_seq_len - cls.CPP_RUNNER_MAX_OUTPUT_LEN,
                max_output_len=cls.CPP_RUNNER_MAX_OUTPUT_LEN,
                max_beam_width=1,
                max_attention_window_size=None,
                sink_token_length=None,
            )
        else:
            logger.info("TensorRT-LLM: Using \"ModelRunner\"")
            runner_cls = ModelRunner

        # Load the model
        runner = runner_cls.from_dir(**runner_kwargs)

        result = cls()
        result.model = runner
        result.runtime_rank = runtime_rank

        return result

    @staticmethod
    def _resolve_pad_id(tokenizer):
        """Return the tokenizer's pad id, or its EOS id when no pad id is set.

        The check is ``is None`` rather than truthiness so that a pad id of
        ``0`` is kept instead of being replaced by the EOS id.
        """
        pad_id = tokenizer.pad_token_id
        return tokenizer.eos_token_id if pad_id is None else pad_id

    @classmethod
    def _resolve_max_new_tokens(cls, state, prompt_length, cpp_runner):
        """Pick how many new tokens to request for one generation call.

        Args:
            state: Generation settings; reads ``max_new_tokens`` and, for the
                non-C++ runner, ``auto_max_new_tokens`` / ``truncation_length``.
            prompt_length: Number of prompt tokens after truncation.
            cpp_runner: Whether the "ModelRunnerCpp" backend is in use.

        Returns:
            The ``max_new_tokens`` value to pass to the runner.
        """
        if cpp_runner:
            # Never ask for more than the budget the C++ runner was loaded with.
            return min(cls.CPP_RUNNER_MAX_OUTPUT_LEN, state['max_new_tokens'])

        if state['auto_max_new_tokens']:
            # Use whatever room is left after the prompt.
            return state['truncation_length'] - prompt_length

        return state['max_new_tokens']

    @staticmethod
    def _decode_new_text(outputs, state, starting_from):
        """Decode one runner output dict into text.

        Reads the first entry of ``outputs['sequence_lengths']`` and the first
        sequence of ``outputs['output_ids']``, trims the ids to that length and
        hands them to ``get_reply_from_output_ids`` with ``starting_from``.

        Returns:
            A ``(text, sequence_length)`` tuple.
        """
        sequence_length = outputs['sequence_lengths'][0].item()
        output_ids = outputs['output_ids'][0][0][:sequence_length].tolist()
        text = get_reply_from_output_ids(output_ids, state, starting_from=starting_from)
        return text, sequence_length

    def generate_with_streaming(self, prompt, state):
        """Generate a reply for ``prompt``, yielding the cumulative text.

        With the C++ runner the generation is requested with ``streaming=False``
        and the full reply is yielded once. Otherwise the runner output is
        iterated and the growing reply is yielded after every item, stopping
        early if ``shared.stop_everything`` becomes true.

        Args:
            prompt: Prompt text, encoded with ``shared.tokenizer``.
            state: Generation settings (sampling parameters, token limits,
                ``ban_eos_token``).

        Yields:
            The reply accumulated so far, as a string.
        """
        cpp_runner = shared.args.cpp_runner
        tokenizer = shared.tokenizer

        input_ids = tokenizer.encode(
            prompt,
            add_special_tokens=True,
            truncation=False,
        )
        input_ids = torch.tensor(input_ids, dtype=torch.int32)
        input_ids = input_ids[-get_max_prompt_length(state):]  # Apply truncation_length
        batch_input_ids = [input_ids]
        prompt_length = input_ids.shape[-1]

        max_new_tokens = self._resolve_max_new_tokens(state, prompt_length, cpp_runner)

        # NOTE(review): with streaming=True the returned object is iterated
        # further down, outside this no_grad block -- confirm that is intended.
        with torch.no_grad():
            outputs = self.model.generate(
                batch_input_ids,
                max_new_tokens=max_new_tokens,
                max_attention_window_size=None,
                sink_token_length=None,
                # -1 is passed as the end id when the EOS token is banned.
                end_id=-1 if state['ban_eos_token'] else tokenizer.eos_token_id,
                pad_id=self._resolve_pad_id(tokenizer),
                temperature=state['temperature'],
                top_k=state['top_k'],
                top_p=state['top_p'],
                num_beams=1,
                length_penalty=1.0,
                repetition_penalty=state['repetition_penalty'],
                presence_penalty=state['presence_penalty'],
                frequency_penalty=state['frequency_penalty'],
                stop_words_list=None,
                bad_words_list=None,
                lora_uids=None,
                prompt_table_path=None,
                prompt_tasks=None,
                streaming=not cpp_runner,
                output_sequence_lengths=True,
                return_dict=True,
                medusa_choices=None
            )

        torch.cuda.synchronize()

        cumulative_reply = ''
        # Decoding starts right after the prompt tokens.
        starting_from = prompt_length

        if cpp_runner:
            # Non-streaming: ``outputs`` is a single result dict.
            new_text, _ = self._decode_new_text(outputs, state, starting_from)
            cumulative_reply += new_text
            yield cumulative_reply
            return

        for curr_outputs in outputs:
            if shared.stop_everything:
                break

            new_text, sequence_length = self._decode_new_text(curr_outputs, state, starting_from)
            cumulative_reply += new_text
            # The next item is decoded from where this one ended.
            starting_from = sequence_length
            yield cumulative_reply

    def generate(self, prompt, state):
        """Run :meth:`generate_with_streaming` to completion.

        Returns:
            The last value yielded, or ``''`` if nothing was yielded.
        """
        output = ''
        for output in self.generate_with_streaming(prompt, state):
            pass

        return output