from llama_cpp import Llama
import gc
import threading
import logging
import sys

# Logger used by LlmBackend for its progress and error messages.
log = logging.getLogger(name='llm_api.backend')
    
class LlmBackend:
    """Process-wide singleton around a llama_cpp ``Llama`` model.

    The backend holds at most one loaded model. It remembers the
    parameters the model was loaded with, so the model can be reloaded
    after :meth:`unload_model`. It also builds Saiga-style prompts from
    raw token ids.

    Locking: ``_lock`` is a plain, non-reentrant ``threading.Lock``.
    Public methods acquire it at most once. Private helpers whose names
    end in ``_locked`` expect the caller to already hold it and never
    acquire it themselves. This avoids a self-deadlock on the reload path,
    e.g. ``generate_tokens`` -> reload -> ``load_model``.

    A plain Lock is used rather than an RLock for two reasons:
    ``generate_tokens`` keeps the lock held across ``yield``s, and a Lock,
    unlike an RLock, may be released by a different thread than the one
    that acquired it. That can happen if the consumer resumes the
    generator from several worker threads.
    """

    # NOTE(review): the second sentence ("Ты максимально точно и отвечаешь ...")
    # appears to be missing a word. Confirm the intended wording before
    # editing, since the prompt text affects model output.
    SYSTEM_PROMPT = "Ты — русскоязычный автоматический ассистент. Ты максимально точно и отвечаешь на запросы пользователя, используя русский язык."
    # Raw token ids placed after BOS to mark each message's role in the
    # Saiga prompt layout. These are presumably tied to the model's
    # tokenizer; verify them if the model changes.
    SYSTEM_TOKEN = 1788
    USER_TOKEN = 1404
    BOT_TOKEN = 9225
    # Token id inserted right after the role token (named as a line break).
    LINEBREAK_TOKEN = 13

    # Maps the message "from" field to its role token id.
    ROLE_TOKENS = {
        "user": USER_TOKEN,
        "bot": BOT_TOKEN,
        "system": SYSTEM_TOKEN
    }

    _instance = None
    # The loaded Llama instance, or None when nothing is loaded.
    _model = None
    # kwargs of the last load_model() call, used to reload after unload.
    _model_params = None
    # See the class docstring for why this is a non-reentrant Lock.
    _lock = threading.Lock()

    def __new__(cls):
        # Singleton: every LlmBackend() call returns the same instance.
        if cls._instance is None:
            cls._instance = super(LlmBackend, cls).__new__(cls)
        return cls._instance

    def is_model_loaded(self):
        """Return True if a model is currently held by the backend."""
        return self._model is not None

    def load_model(self, model_path, context_size=2000, enable_gpu=True, gpu_layer_number=35, chat_format='llama-2'):
        """Load a model, replacing any currently loaded one.

        Args:
            model_path: Path passed to ``Llama(model_path=...)``.
            context_size: Passed as ``n_ctx``.
            enable_gpu: If True, ``gpu_layer_number`` is passed as
                ``n_gpu_layers``. Otherwise llama_cpp's default is used.
            gpu_layer_number: Layer count offloaded when ``enable_gpu``.
            chat_format: Passed as ``chat_format``.

        Returns:
            The newly created ``Llama`` instance.
        """
        with self._lock:
            return self._load_model_locked(
                model_path=model_path,
                context_size=context_size,
                enable_gpu=enable_gpu,
                gpu_layer_number=gpu_layer_number,
                chat_format=chat_format,
            )

    def _load_model_locked(self, model_path, context_size, enable_gpu, gpu_layer_number, chat_format):
        """Load a model. The caller must hold ``self._lock``."""
        log.info('load_model - started')
        # Record parameters first so a later ensure_model_is_loaded() can
        # retry with them even if construction below fails.
        self._model_params = {
            'model_path': model_path,
            'context_size': context_size,
            'enable_gpu': enable_gpu,
            'gpu_layer_number': gpu_layer_number,
            'chat_format': chat_format,
        }
        # Drop the old model before allocating the new one, so the old
        # weights can be freed first (as long as nothing else references them).
        self._unload_model_locked()

        llama_kwargs = {
            'model_path': model_path,
            'chat_format': chat_format,
            'n_ctx': context_size,
            'n_parts': 1,
            'logits_all': True,
            'verbose': True,
        }
        if enable_gpu:
            llama_kwargs['n_gpu_layers'] = gpu_layer_number
        self._model = Llama(**llama_kwargs)
        log.info('load_model - finished')
        return self._model

    def set_system_prompt(self, prompt):
        """Replace the system prompt used when building Saiga prompts."""
        with self._lock:
            self.SYSTEM_PROMPT = prompt

    def unload_model(self):
        """Release the current model. Its parameters are kept for a later reload."""
        log.info('unload_model - started')
        with self._lock:
            self._unload_model_locked()
        log.info('unload_model - finished')

    def _unload_model_locked(self):
        """Drop the model reference. The caller must hold ``self._lock``."""
        if self._model is not None:
            self._model = None
            # Collect now in case the old Llama object sits in a reference
            # cycle, so its memory is released before any new model loads.
            gc.collect()

    def ensure_model_is_loaded(self):
        """Reload the model from the remembered parameters if it is missing.

        Does nothing if a model is loaded. If no parameters were ever
        recorded, it only logs that a reload is impossible.
        """
        log.info('ensure_model_is_loaded - started')
        if not self.is_model_loaded():
            with self._lock:
                self._ensure_model_is_loaded_locked()
        log.info('ensure_model_is_loaded - finished')

    def _ensure_model_is_loaded_locked(self):
        """Reload the model if needed. The caller must hold ``self._lock``."""
        # Re-check under the lock: another thread may already have reloaded it.
        if self._model is not None:
            return
        log.info('ensure_model_is_loaded - model reloading')
        if self._model_params is not None:
            self._load_model_locked(**self._model_params)
        else:
            log.info('ensure_model_is_loaded - No model config found. Reloading can not be done.')

    def create_chat_completion(self, messages, stream=True):
        """Run ``Llama.create_chat_completion`` on ``messages``.

        The model is reloaded first if it was unloaded.

        Returns:
            Whatever llama_cpp returns, or None if the call raised. The
            error is logged with its traceback.

        NOTE(review): with ``stream=True`` the result is presumably a lazy
        iterator consumed after ``_lock`` is released. In that case the
        lock guards only the call, not the generation itself. Confirm this
        is acceptable under concurrent use.
        """
        log.info('create_chat_completion called')
        with self._lock:
            log.info('create_chat_completion started')
            try:
                self._ensure_model_is_loaded_locked()
                return self._model.create_chat_completion(messages=messages, stream=stream)
            except Exception:
                log.exception('create_chat_completion - error')
                return None

    def get_message_tokens(self, role, content):
        """Tokenize one chat message in the Saiga layout.

        Args:
            role: A key of ``ROLE_TOKENS`` ("user", "bot" or "system").
            content: Message text. It is UTF-8 encoded before tokenizing.

        Returns:
            The content tokens with the role token and ``LINEBREAK_TOKEN``
            inserted at positions 1 and 2, followed by the model's EOS token.

        Raises:
            KeyError: if ``role`` is not in ``ROLE_TOKENS``.
        """
        log.info('get_message_tokens - started')
        self.ensure_model_is_loaded()
        message_tokens = self._build_message_tokens(role, content)
        log.info('get_message_tokens - finished')
        return message_tokens

    def _build_message_tokens(self, role, content):
        """Build message tokens without loading the model or taking the lock.

        The caller must make sure ``self._model`` is set.
        """
        message_tokens = self._model.tokenize(content.encode("utf-8"))
        # Assumes tokenize() prepends BOS, giving [BOS, role, LF, content..., EOS].
        # NOTE(review): confirm BOS is always added for this model.
        message_tokens.insert(1, self.ROLE_TOKENS[role])
        message_tokens.insert(2, self.LINEBREAK_TOKEN)
        message_tokens.append(self._model.token_eos())
        return message_tokens

    def get_system_tokens(self):
        """Return the tokenized system message for the current SYSTEM_PROMPT."""
        return self.get_message_tokens(role="system", content=self.SYSTEM_PROMPT)

    def create_chat_generator_for_saiga(self, messages, parameters, use_system_prompt=True):
        """Build a Saiga prompt and return the model's token generator.

        Args:
            messages: Iterable of mappings with a "from" key (a
                ``ROLE_TOKENS`` key) and an optional "content" key
                (default "").
            parameters: Mapping with "top_k", "top_p", "temperature" and
                "repetition_penalty".
            use_system_prompt: Prepend the SYSTEM_PROMPT message if True.

        Returns:
            The generator from ``Llama.generate``. Pass it to
            :meth:`generate_tokens` to obtain text chunks.

        Raises:
            RuntimeError: if no model is loaded and none can be reloaded.
            KeyError: if a message has an unknown "from" role.
        """
        log.info('create_chat_generator_for_saiga - started')
        with self._lock:
            self._ensure_model_is_loaded_locked()
            if self._model is None:
                raise RuntimeError('create_chat_generator_for_saiga - no model loaded and no model config to reload from')

            tokens = self._build_message_tokens("system", self.SYSTEM_PROMPT) if use_system_prompt else []
            for message in messages:
                tokens.extend(self._build_message_tokens(message.get("from"), message.get("content", "")))

            # Open the bot's turn so that generation produces its reply.
            tokens.extend([self._model.token_bos(), self.BOT_TOKEN, self.LINEBREAK_TOKEN])
            generator = self._model.generate(
                tokens,
                top_k=parameters['top_k'],
                top_p=parameters['top_p'],
                temp=parameters['temperature'],
                repeat_penalty=parameters['repetition_penalty'],
            )
        log.info('create_chat_generator_for_saiga - finished')
        return generator

    def generate_tokens(self, generator):
        """Turn a token-id generator into a stream of byte chunks.

        ``_lock`` is held for the whole iteration. Iteration stops at the
        model's EOS token. The final chunk is always ``b''`` as an
        end-of-stream marker: after EOS, after the input generator is
        exhausted, or after an error. Errors are logged, not propagated.
        """
        log.info('generate_tokens - started')
        with self._lock:
            self._ensure_model_is_loaded_locked()
            try:
                eos_token = self._model.token_eos()
                for token in generator:
                    if token == eos_token:
                        break
                    # Yield raw bytes. A single token may carry only part of
                    # a multi-byte UTF-8 character, so decoding each token on
                    # its own could garble the text.
                    yield self._model.detokenize([token])
            except Exception:
                log.exception('generate_tokens - error')
            else:
                log.info('generate_tokens - finished')
            yield b''  # End of chunk