import logging
from auto_gptq import AutoGPTQForCausalLM
from huggingface_hub import hf_hub_download
# from langchain.llms import LlamaCpp 
from langchain_community.llms import LlamaCpp
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LlamaForCausalLM,
    LlamaTokenizer,
)
from langchain_community.llms import HuggingFacePipeline
from langchain.callbacks.manager import CallbackManager
from transformers import GenerationConfig, pipeline
import torch
import os
from app.settings import Config

# Project settings object; supplies the filesystem locations read below.
conf = Config()


logger = logging.getLogger(__name__)

# Directory used as cache_dir for hf_hub_download and for the GPU model load.
MODELS_PATH = conf.MODELS_PATH

CONTEXT_WINDOW_SIZE = 2048  # forwarded to LlamaCpp as n_ctx
MAX_NEW_TOKENS = 2048  # forwarded as max_tokens (LlamaCpp) and max_length (pipeline)
N_BATCH = 512  # forwarded to LlamaCpp as n_batch
N_GPU_LAYERS = 1  # forwarded to LlamaCpp as n_gpu_layers when device_type is "cuda"

CACHE_DIR = conf.CACHE_DIR  # "./models/"

# SECURITY(review): an access token is committed in source. It should be revoked
# and supplied through the environment / secret store instead. Until that is
# done the literal is kept only as a fallback: setdefault() no longer clobbers
# a token that the deployment environment already provides.
os.environ.setdefault("HUGGINGFACEHUB_API_TOKEN", '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')

def load_quantized_model_gguf_ggml(model_id, model_basename, device_type, logging):
    """Download a GGUF/GGML model file and wrap it in a LlamaCpp LLM.

    Args:
        model_id: Hugging Face repository id passed to ``hf_hub_download``.
        model_basename: Name of the model file inside that repository.
        device_type: Device string; ``"mps"`` and ``"cuda"`` (case-insensitive)
            cause ``n_gpu_layers`` to be set, any other value leaves it unset.
        logging: Logger-like object providing ``info()`` and ``exception()``.
            Note that this parameter shadows the stdlib ``logging`` module
            inside the function.

    Returns:
        The constructed ``LlamaCpp`` instance, or ``None`` when downloading
        the file or constructing the model fails (best-effort loading).
    """
    try:
        logging.info("Using Llamacpp for GGUF/GGML quantized models")
        model_path = hf_hub_download(
            repo_id=model_id,
            filename=model_basename,
            resume_download=True,
            # force_download=True,
            cache_dir=MODELS_PATH,
        )
        kwargs = {
            "model_path": model_path,
            "n_ctx": CONTEXT_WINDOW_SIZE,
            "max_tokens": MAX_NEW_TOKENS,
            "n_batch": N_BATCH,  # set this based on your GPU & CPU RAM
        }
        device = device_type.lower()
        if device == "mps":
            kwargs["n_gpu_layers"] = 1
        elif device == "cuda":
            kwargs["n_gpu_layers"] = N_GPU_LAYERS  # set this based on your GPU

        return LlamaCpp(**kwargs)
    except Exception:
        # Keep the best-effort contract (return None) but record the real
        # cause; a bare ``except`` would also trap KeyboardInterrupt/SystemExit.
        logging.exception(
            "Failed to load GGUF/GGML model %s (%s)", model_id, model_basename
        )
        if model_basename and "ggml" in model_basename.lower():
            logging.info("If you were using GGML model, LLAMA-CPP Dropped Support, Use GGUF Instead")
        return None


def load_quantized_model_qptq(model_id, model_basename, device_type, logging):
    """Load a GPTQ-quantized model together with its tokenizer.

    Args:
        model_id: Model identifier handed to both ``AutoTokenizer`` and
            ``AutoGPTQForCausalLM``.
        model_basename: Weight file name; every ``".safetensors"`` occurrence
            is stripped before it is forwarded to ``from_quantized``.
        device_type: Accepted for signature parity with the other loaders;
            not read by this function.
        logging: Logger-like object providing ``info()``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    logging.info("Using AutoGPTQForCausalLM for quantized models")

    # from_quantized is given the basename without the ".safetensors" marker.
    basename_without_ext = (
        model_basename.replace(".safetensors", "")
        if ".safetensors" in model_basename
        else model_basename
    )

    gptq_tokenizer = AutoTokenizer.from_pretrained(model_id, use_fast=True)
    logging.info("Tokenizer loaded")

    quantized_options = {
        "model_basename": basename_without_ext,
        "use_safetensors": True,
        "trust_remote_code": True,
        "device_map": "auto",
        "use_triton": False,
        "quantize_config": None,
    }
    gptq_model = AutoGPTQForCausalLM.from_quantized(model_id, **quantized_options)
    return gptq_model, gptq_tokenizer

def load_full_model(model_id, model_basename, device_type, logging):
    """Load an unquantized model and tokenizer via Hugging Face transformers.

    For ``"mps"``/``"cpu"`` (case-insensitive) the Llama-specific classes are
    used; any other device goes through the ``Auto*`` classes with float16
    weights and ``device_map="auto"``.

    Args:
        model_id: Model identifier passed to ``from_pretrained``.
        model_basename: Accepted for signature parity with the other loaders;
            not read by this function.
        device_type: Device string selecting the loading path.
        logging: Logger-like object providing ``info()``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    if device_type.lower() not in ("mps", "cpu"):
        logging.info("Using AutoModelForCausalLM for full models")
        hub_token = os.environ["HUGGINGFACEHUB_API_TOKEN"]
        tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            cache_dir=CACHE_DIR,
            use_auth_token=hub_token,
        )
        logging.info("Tokenizer loaded")
        # Note: the model weights are cached under MODELS_PATH, while the
        # tokenizer above uses CACHE_DIR.
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype=torch.float16,
            low_cpu_mem_usage=True,
            cache_dir=MODELS_PATH,
            use_auth_token=hub_token,
            # trust_remote_code=True, # set these if you are using NVIDIA GPU
            # load_in_4bit=True,
            # bnb_4bit_quant_type="nf4",
            # bnb_4bit_compute_dtype=torch.float16,
            # max_memory={0: "15GB"} # Uncomment this line if you encounter CUDA out of memory errors
        )
        model.tie_weights()
        return model, tokenizer

    # CPU / Apple MPS path: Llama-specific classes, both cached in CACHE_DIR.
    logging.info("Using LlamaTokenizer")
    shared_kwargs = {
        "cache_dir": CACHE_DIR,
        "use_auth_token": os.environ["HUGGINGFACEHUB_API_TOKEN"],
    }
    tokenizer = LlamaTokenizer.from_pretrained(model_id, **shared_kwargs)
    model = LlamaForCausalLM.from_pretrained(model_id, **shared_kwargs)
    return model, tokenizer



def load_model(device_type, model_id, model_basename=None, LOGGING=logger):
    """Load a local LLM and return it wrapped for LangChain.

    The loader is selected from ``model_basename``:

    * contains ``.gguf`` or ``.ggml`` (case-insensitive): llama.cpp loader;
    * any other non-``None`` value: GPTQ loader;
    * ``None``: full (unquantized) model loader.

    Args:
        device_type: Device string forwarded to the selected loader.
        model_id: Model identifier forwarded to the loader and to
            ``GenerationConfig.from_pretrained``.
        model_basename: Optional model file name inside the repository.
        LOGGING: Logger-like object forwarded to the loader helpers.

    Returns:
        For GGUF/GGML files, whatever ``load_quantized_model_gguf_ggml``
        returns (a ``LlamaCpp`` instance, or ``None`` on failure). Otherwise
        a ``HuggingFacePipeline`` around a text-generation pipeline.
    """
    logger.info(f"Loading Model: {model_id}, on: {device_type}")
    logger.info("This action can take a few minutes!")

    if model_basename is not None:
        basename = model_basename.lower()
        if ".gguf" in basename or ".ggml" in basename:
            # The llama.cpp helper returns one ready-to-use LLM object (or
            # None), never a (model, tokenizer) pair, so it must be returned
            # directly rather than unpacked.
            return load_quantized_model_gguf_ggml(
                model_id, model_basename, device_type, LOGGING)
        model, tokenizer = load_quantized_model_qptq(
            model_id, model_basename, device_type, LOGGING)
    else:
        model, tokenizer = load_full_model(
            model_id, model_basename, device_type, LOGGING)

    # Load configuration from the model to avoid warnings
    generation_config = GenerationConfig.from_pretrained(model_id)

    pipe = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        max_length=MAX_NEW_TOKENS,
        temperature=0.1,
        # top_p=0.95,
        repetition_penalty=1.15,
        generation_config=generation_config,
    )

    local_llm = HuggingFacePipeline(pipeline=pipe)
    logger.info("Local LLM Loaded")
    return local_llm