import logging
import requests
from requests.adapters import HTTPAdapter
from urllib3.util import Retry

from langchain_community.llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain_core.language_models import LLM

from global_config import GlobalConfig


HF_API_HEADERS = {"Authorization": f"Bearer {GlobalConfig.HUGGINGFACEHUB_API_TOKEN}"}
REQUEST_TIMEOUT = 35

logger = logging.getLogger(__name__)

# Retry failed POSTs up to five times on transient gateway errors (HTTP 502,
# 503, and 504), using exponential backoff (0.25 s base factor, doubling on
# each attempt) with up to 0.3 s of random jitter added to each wait.
# NOTE: backoff_jitter requires urllib3 >= 2.0.
retries = Retry(
    total=5,
    backoff_factor=0.25,
    backoff_jitter=0.3,
    status_forcelist=[502, 503, 504],
    allowed_methods={'POST'},
)
adapter = HTTPAdapter(max_retries=retries)

# Share a single session so that every HTTP(S) request gets the retry behavior
http_session = requests.Session()
http_session.mount('https://', adapter)
http_session.mount('http://', adapter)

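
# Illustrative sketch (added for demonstration; not used elsewhere in this
# module): how the retry-enabled session is intended to be used for a
# resilient POST. The `url` argument is a caller-supplied placeholder, since
# no endpoint URL is defined in this module.
def _post_json_with_retries(url: str, payload: dict) -> requests.Response:
    """
    POST a JSON payload through the shared retrying session.

    :param url: The endpoint URL (hypothetical here).
    :param payload: The JSON-serializable request body.
    :return: The HTTP response; 502/503/504 retries are handled by the adapter.
    """

    return http_session.post(
        url,
        headers=HF_API_HEADERS,
        json=payload,
        timeout=REQUEST_TIMEOUT,
    )
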

def get_hf_endpoint(repo_id: str, max_new_tokens: int) -> LLM:
    """
    Create an LLM backed by a Hugging Face inference endpoint via LangChain's
    HuggingFaceEndpoint.

    :param repo_id: The Hugging Face model repository ID.
    :param max_new_tokens: The maximum number of new tokens to generate.
    :return: The HF LLM inference endpoint.
    """

    logger.debug('Getting LLM via HF endpoint: %s', repo_id)

    return HuggingFaceEndpoint(
        repo_id=repo_id,
        max_new_tokens=max_new_tokens,
        top_k=40,
        top_p=0.95,
        temperature=GlobalConfig.LLM_MODEL_TEMPERATURE,
        repetition_penalty=1.03,  # Mildly discourage verbatim repetition
        streaming=True,
        huggingfacehub_api_token=GlobalConfig.HUGGINGFACEHUB_API_TOKEN,
        return_full_text=False,  # Return only the generated text, not the prompt
        stop_sequences=['</s>'],
    )


# def hf_api_query(payload: dict) -> dict:
#     """
#     Invoke HF inference end-point API.
#
#     :param payload: The prompt for the LLM and related parameters.
#     :return: The output from the LLM.
#     """
#
#     try:
#         response = http_session.post(
#             HF_API_URL,
#             headers=HF_API_HEADERS,
#             json=payload,
#             timeout=REQUEST_TIMEOUT
#         )
#         result = response.json()
#     except requests.exceptions.Timeout as te:
#         logger.error('*** Error: hf_api_query timeout! %s', str(te))
#         result = {}  # Match the declared dict return type on timeout
#
#     return result


# def generate_slides_content(topic: str) -> str:
#     """
#     Generate the outline/contents of slides for a presentation on a given topic.
#
#     :param topic: Topic on which slides are to be generated.
#     :return: The content in JSON format.
#     """
#
#     with open(GlobalConfig.SLIDES_TEMPLATE_FILE, 'r', encoding='utf-8') as in_file:
#         template_txt = in_file.read().strip()
#         template_txt = template_txt.replace('<REPLACE_PLACEHOLDER>', topic)
#
#     output = hf_api_query({
#         'inputs': template_txt,
#         'parameters': {
#             'temperature': GlobalConfig.LLM_MODEL_TEMPERATURE,
#             'min_length': GlobalConfig.LLM_MODEL_MIN_OUTPUT_LENGTH,
#             'max_length': GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH,
#             'max_new_tokens': GlobalConfig.LLM_MODEL_MAX_OUTPUT_LENGTH,
#             'num_return_sequences': 1,
#             'return_full_text': False,
#             # "repetition_penalty": 0.0001
#         },
#         'options': {
#             'wait_for_model': True,
#             'use_cache': True
#         }
#     })
#
#     output = output[0]['generated_text'].strip()
#     # output = output[len(template_txt):]
#
#     json_end_idx = output.rfind('```')
#     if json_end_idx != -1:
#         # logging.debug(f'{json_end_idx=}')
#         output = output[:json_end_idx]
#
#     logger.debug('generate_slides_content: output: %s', output)
#
#     return output


if __name__ == '__main__':
    # results = get_related_websites('5G AI WiFi 6')
    #
    # for a_result in results.results:
    #     print(a_result.title, a_result.url, a_result.extract)

    # get_ai_image('A talk on AI, covering pros and cons')
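
    # Illustrative sketch: exercising get_hf_endpoint. The repo ID below is an
    # assumed example; a valid HUGGINGFACEHUB_API_TOKEN must be configured.
    # llm = get_hf_endpoint(
    #     repo_id='mistralai/Mistral-7B-Instruct-v0.2',
    #     max_new_tokens=512,
    # )
    # print(llm.invoke('Write a one-line welcome message for a talk on AI.'))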
    pass