import logging
import requests
from requests.adapters import HTTPAdapter
from urllib3.util import Retry

from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain_core.language_models import LLM

from global_config import GlobalConfig


# Used only by the legacy, commented-out HF Inference API helpers below
HF_API_HEADERS = {'Authorization': f'Bearer {GlobalConfig.HUGGINGFACEHUB_API_TOKEN}'}
REQUEST_TIMEOUT = 35  # seconds

logger = logging.getLogger(__name__)

# Retry POST requests up to five times on transient gateway errors
# (502/503/504), with exponential backoff and jitter
retries = Retry(
    total=5,
    backoff_factor=0.25,
    backoff_jitter=0.3,
    status_forcelist=[502, 503, 504],
    allowed_methods={'POST'},
)
adapter = HTTPAdapter(max_retries=retries)
# A shared session, so every HTTP(S) call in this module gets the retry policy
http_session = requests.Session()
http_session.mount('https://', adapter)
http_session.mount('http://', adapter)
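
# Any request issued through http_session retries transient gateway errors
# automatically; a minimal sketch (the URL is hypothetical):
#
#     resp = http_session.post('https://example.com/api', json={}, timeout=REQUEST_TIMEOUT)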


def get_hf_endpoint(repo_id: str, max_new_tokens: int) -> LLM:
    """
    Get an LLM via the HuggingFaceEndpoint of LangChain.

    :param repo_id: The model name.
    :param max_new_tokens: The max new tokens to generate.
    :return: The HF LLM inference endpoint.
    """

    logger.debug('Getting LLM via HF endpoint: %s', repo_id)

    return HuggingFaceEndpoint(
        repo_id=repo_id,
        max_new_tokens=max_new_tokens,
        top_k=40,
        top_p=0.95,
        temperature=GlobalConfig.LLM_MODEL_TEMPERATURE,
        repetition_penalty=1.03,
        streaming=True,
        huggingfacehub_api_token=GlobalConfig.HUGGINGFACEHUB_API_TOKEN,
        return_full_text=False,
        stop_sequences=['</s>'],
    )
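
# A minimal usage sketch (the repo ID is an assumption, not a project default;
# assumes a valid HUGGINGFACEHUB_API_TOKEN and a text-generation model):
#
#     llm = get_hf_endpoint('mistralai/Mistral-7B-Instruct-v0.2', max_new_tokens=256)
#     for chunk in llm.stream('Summarize Wi-Fi 6 in one sentence.'):
#         print(chunk, end='', flush=True)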


# def hf_api_query(payload: dict) -> list:
#     """
#     Invoke the HF Inference API endpoint.
#     Legacy helper: HF_API_URL must be defined before this is re-enabled.
#
#     :param payload: The prompt for the LLM and related parameters.
#     :return: The output from the LLM; an empty list on timeout.
#     """
#
#     try:
#         response = http_session.post(
#             HF_API_URL,
#             headers=HF_API_HEADERS,
#             json=payload,
#             timeout=REQUEST_TIMEOUT
#         )
#         result = response.json()
#     except requests.exceptions.Timeout as te:
#         logger.error('*** Error: hf_api_query timeout! %s', str(te))
#         result = []
#
#     return result


# def generate_slides_content(topic: str) -> str:
#     """
#     Generate the outline/contents of slides for a presentation on a given topic.
#
#     :param topic: Topic on which slides are to be generated.
#     :return: The content in JSON format.
#     """
#
#     with open(GlobalConfig.SLIDES_TEMPLATE_FILE, 'r', encoding='utf-8') as in_file:
#         template_txt = in_file.read().strip()
#         template_txt = template_txt.replace('<REPLACE_PLACEHOLDER>', topic)
#
#     output = hf_api_query({
#         'inputs': template_txt,
#         'parameters': {
#             'temperature': GlobalConfig.LLM_MODEL_TEMPERATURE,
#             'min_length': GlobalConfig.LLM_MODEL_MIN_OUTPUT_LENGTH,
#             'max_length': GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH,
#             'max_new_tokens': GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH,
#             'num_return_sequences': 1,
#             'return_full_text': False,
#             # "repetition_penalty": 0.0001
#         },
#         'options': {
#             'wait_for_model': True,
#             'use_cache': True
#         }
#     })
#
#     # Note: this raises IndexError if hf_api_query timed out and returned []
#     output = output[0]['generated_text'].strip()
#     # output = output[len(template_txt):]
#
#     json_end_idx = output.rfind('```')
#     if json_end_idx != -1:
#         # logging.debug(f'{json_end_idx=}')
#         output = output[:json_end_idx]
#
#     logger.debug('generate_slides_content: output: %s', output)
#
#     return output


if __name__ == '__main__':
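    # A hypothetical smoke test (assumes a valid HUGGINGFACEHUB_API_TOKEN and
    # network access; the repo ID is an assumption):
    #
    #     llm = get_hf_endpoint('mistralai/Mistral-7B-Instruct-v0.2', max_new_tokens=64)
    #     print(llm.invoke('Say hello in five words.'))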
    pass