import json
import os

import colorama
import requests
import logging

from modules.models.base_model import BaseLLMModel
from modules.presets import STANDARD_ERROR_MSG, GENERAL_ERROR_MSG, TIMEOUT_STREAMING, TIMEOUT_ALL, i18n

group_id = os.environ.get("MINIMAX_GROUP_ID", "")


class MiniMax_Client(BaseLLMModel):
    """
    MiniMax Client
    接口文档见 https://api.minimax.chat/document/guides/chat
    """

    def __init__(self, model_name, api_key, user_name="", system_prompt=None):
        super().__init__(model_name=model_name, user=user_name)
        self.url = f'https://api.minimax.chat/v1/text/chatcompletion?GroupId={group_id}'
        self.history = []
        self.api_key = api_key
        self.system_prompt = system_prompt
        self.headers = {
            "Authorization": f"Bearer {api_key}",
            "Content-Type": "application/json"
        }

    def get_answer_at_once(self):
        # minimax temperature is (0,1] and base model temperature is [0,2], and yuan 0.9 == base 1 so need to convert
        temperature = self.temperature * 0.9 if self.temperature <= 1 else 0.9 + (self.temperature - 1) / 10

        request_body = {
            "model": self.model_name.replace('minimax-', ''),
            "temperature": temperature,
            "skip_info_mask": True,
            'messages': [{"sender_type": "USER", "text": self.history[-1]['content']}]
        }
        if self.n_choices:
            request_body['beam_width'] = self.n_choices
        if self.system_prompt:
            request_body['prompt'] = self.system_prompt
        if self.max_generation_token:
            request_body['tokens_to_generate'] = self.max_generation_token
        if self.top_p:
            request_body['top_p'] = self.top_p

        response = requests.post(self.url, headers=self.headers, json=request_body)

        res = response.json()
        answer = res['reply']
        total_token_count = res["usage"]["total_tokens"]
        return answer, total_token_count

    def get_answer_stream_iter(self):
        response = self._get_response(stream=True)
        if response is not None:
            iter = self._decode_chat_response(response)
            partial_text = ""
            for i in iter:
                partial_text += i
                yield partial_text
        else:
            yield STANDARD_ERROR_MSG + GENERAL_ERROR_MSG

    def _get_response(self, stream=False):
        minimax_api_key = self.api_key
        history = self.history
        logging.debug(colorama.Fore.YELLOW +
                      f"{history}" + colorama.Fore.RESET)
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {minimax_api_key}",
        }

        temperature = self.temperature * 0.9 if self.temperature <= 1 else 0.9 + (self.temperature - 1) / 10

        messages = []
        for msg in self.history:
            if msg['role'] == 'user':
                messages.append({"sender_type": "USER", "text": msg['content']})
            else:
                messages.append({"sender_type": "BOT", "text": msg['content']})

        request_body = {
            "model": self.model_name.replace('minimax-', ''),
            "temperature": temperature,
            "skip_info_mask": True,
            'messages': messages
        }
        if self.n_choices:
            request_body['beam_width'] = self.n_choices
        if self.system_prompt:
            lines = self.system_prompt.splitlines()
            if lines[0].find(":") != -1 and len(lines[0]) < 20:
                request_body["role_meta"] = {
                    "user_name": lines[0].split(":")[0],
                    "bot_name": lines[0].split(":")[1]
                }
                lines.pop()
            request_body["prompt"] = "\n".join(lines)
        if self.max_generation_token:
            request_body['tokens_to_generate'] = self.max_generation_token
        else:
            request_body['tokens_to_generate'] = 512
        if self.top_p:
            request_body['top_p'] = self.top_p

        if stream:
            timeout = TIMEOUT_STREAMING
            request_body['stream'] = True
            request_body['use_standard_sse'] = True
        else:
            timeout = TIMEOUT_ALL
        try:
            response = requests.post(
                self.url,
                headers=headers,
                json=request_body,
                stream=stream,
                timeout=timeout,
            )
        except:
            return None

        return response

    def _decode_chat_response(self, response):
        error_msg = ""
        for chunk in response.iter_lines():
            if chunk:
                chunk = chunk.decode()
                chunk_length = len(chunk)
                print(chunk)
                try:
                    chunk = json.loads(chunk[6:])
                except json.JSONDecodeError:
                    print(i18n("JSON解析错误,收到的内容: ") + f"{chunk}")
                    error_msg += chunk
                    continue
                if chunk_length > 6 and "delta" in chunk["choices"][0]:
                    if "finish_reason" in chunk["choices"][0] and chunk["choices"][0]["finish_reason"] == "stop":
                        self.all_token_counts.append(chunk["usage"]["total_tokens"] - sum(self.all_token_counts))
                        break
                    try:
                        yield chunk["choices"][0]["delta"]
                    except Exception as e:
                        logging.error(f"Error: {e}")
                        continue
        if error_msg:
            try:
                error_msg = json.loads(error_msg)
                if 'base_resp' in error_msg:
                    status_code = error_msg['base_resp']['status_code']
                    status_msg = error_msg['base_resp']['status_msg']
                    raise Exception(f"{status_code} - {status_msg}")
            except json.JSONDecodeError:
                pass
            raise Exception(error_msg)