import json
import re

import requests
from tiktoken import get_encoding as tiktoken_get_encoding

from messagers.message_outputer import OpenaiStreamOutputer
from utils.logger import logger
from utils.enver import enver
class MessageStreamer:
    MODEL_MAP = {
        "mixtral-8x7b": "mistralai/Mixtral-8x7B-Instruct-v0.1",  # 72.62, fast [Recommended]
        "mistral-7b": "mistralai/Mistral-7B-Instruct-v0.2",  # 65.71, fast
        "nous-mixtral-8x7b": "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO",
        "gemma-7b": "google/gemma-7b-it",
        # "openchat-3.5": "openchat/openchat-3.5-1210",  # 68.89, fast
        # "zephyr-7b-beta": "HuggingFaceH4/zephyr-7b-beta",  # ❌ Too Slow
        # "llama-70b": "meta-llama/Llama-2-70b-chat-hf",  # ❌ Require Pro User
        # "codellama-34b": "codellama/CodeLlama-34b-Instruct-hf",  # ❌ Low Score
        # "falcon-180b": "tiiuae/falcon-180B-chat",  # ❌ Require Pro User
        "default": "mistralai/Mixtral-8x7B-Instruct-v0.1",
    }

    STOP_SEQUENCES_MAP = {
        "mixtral-8x7b": "</s>",
        "mistral-7b": "</s>",
        "nous-mixtral-8x7b": "<|im_end|>",
        "openchat-3.5": "<|end_of_turn|>",
        "gemma-7b": "<eos>",
    }

    TOKEN_LIMIT_MAP = {
        "mixtral-8x7b": 32768,
        "mistral-7b": 32768,
        "nous-mixtral-8x7b": 32768,
        "openchat-3.5": 8192,
        "gemma-7b": 8192,
    }

    TOKEN_RESERVED = 100
    def __init__(self, model: str):
        if model in self.MODEL_MAP:
            self.model = model
        else:
            self.model = "default"
        self.model_fullname = self.MODEL_MAP[self.model]
        self.message_outputer = OpenaiStreamOutputer()
        self.tokenizer = tiktoken_get_encoding("cl100k_base")
    def parse_line(self, line):
        line = line.decode("utf-8")
        line = re.sub(r"data:\s*", "", line)
        data = json.loads(line)
        try:
            content = data["token"]["text"]
        except (KeyError, TypeError):
            # Malformed or non-token event: log it and emit nothing for this line
            logger.err(data)
            content = ""
        return content
    def count_tokens(self, text):
        tokens = self.tokenizer.encode(text)
        token_count = len(tokens)
        logger.note(f"Prompt Token Count: {token_count}")
        return token_count
    def chat_response(
        self,
        prompt: str = None,
        temperature: float = 0.5,
        top_p: float = 0.95,
        max_new_tokens: int = None,
        api_key: str = None,
        use_cache: bool = False,
    ):
        # https://huggingface.co/docs/api-inference/detailed_parameters?code=curl
        # curl --proxy http://<server>:<port> https://api-inference.huggingface.co/models/<org>/<model_name> -X POST -d '{"inputs":"who are you?","parameters":{"max_new_tokens":64}}' -H 'Content-Type: application/json' -H 'Authorization: Bearer <HF_TOKEN>'
        self.request_url = (
            f"https://api-inference.huggingface.co/models/{self.model_fullname}"
        )
        self.request_headers = {
            "Content-Type": "application/json",
        }

        if api_key:
            logger.note(
                f"Using API Key: {api_key[:3]}{(len(api_key)-7)*'*'}{api_key[-4:]}"
            )
            self.request_headers["Authorization"] = f"Bearer {api_key}"
        if temperature is None or temperature < 0:
            temperature = 0.0
        # temperature and top_p must be within (0, 1) for HF LLM models
        temperature = max(temperature, 0.01)
        temperature = min(temperature, 0.99)
        top_p = max(top_p, 0.01)
        top_p = min(top_p, 0.99)

        token_limit = int(
            self.TOKEN_LIMIT_MAP[self.model]
            - self.TOKEN_RESERVED
            - self.count_tokens(prompt) * 1.35
        )
        if token_limit <= 0:
            raise ValueError("Prompt exceeded token limit!")

        if max_new_tokens is None or max_new_tokens <= 0:
            max_new_tokens = token_limit
        else:
            max_new_tokens = min(max_new_tokens, token_limit)
        # References:
        #   huggingface_hub/inference/_client.py:
        #     class InferenceClient > def text_generation()
        #   huggingface_hub/inference/_text_generation.py:
        #     class TextGenerationRequest > param `stream`
        # https://huggingface.co/docs/text-generation-inference/conceptual/streaming#streaming-with-curl
        # https://huggingface.co/docs/api-inference/detailed_parameters#text-generation-task
        self.request_body = {
            "inputs": prompt,
            "parameters": {
                "temperature": temperature,
                "top_p": top_p,
                "max_new_tokens": max_new_tokens,
                "return_full_text": False,
            },
            "options": {
                "use_cache": use_cache,
            },
            "stream": True,
        }
        # Stop sequence for the selected model; None when the model has no entry,
        # so later comparisons in the stream parsers never hit an unset attribute
        self.stop_sequences = self.STOP_SEQUENCES_MAP.get(self.model, None)
        # self.request_body["parameters"]["stop_sequences"] = [
        #     self.STOP_SEQUENCES_MAP[self.model]
        # ]
        logger.back(self.request_url)
        enver.set_envs(proxies=True)
        stream_response = requests.post(
            self.request_url,
            headers=self.request_headers,
            json=self.request_body,
            proxies=enver.requests_proxies,
            stream=True,
        )
        status_code = stream_response.status_code
        if status_code == 200:
            logger.success(status_code)
        else:
            logger.err(status_code)

        return stream_response
    def chat_return_dict(self, stream_response):
        # https://platform.openai.com/docs/guides/text-generation/chat-completions-response-format
        final_output = self.message_outputer.default_data.copy()
        final_output["choices"] = [
            {
                "index": 0,
                "finish_reason": "stop",
                "message": {
                    "role": "assistant",
                    "content": "",
                },
            }
        ]
        logger.back(final_output)

        final_content = ""
        for line in stream_response.iter_lines():
            if not line:
                continue
            content = self.parse_line(line)

            if content.strip() == self.stop_sequences:
                logger.success("\n[Finished]")
                break
            else:
                logger.back(content, end="")
                final_content += content

        if self.model in self.STOP_SEQUENCES_MAP:
            final_content = final_content.replace(self.stop_sequences, "")

        final_content = final_content.strip()
        final_output["choices"][0]["message"]["content"] = final_content
        return final_output
    def chat_return_generator(self, stream_response):
        is_finished = False
        line_count = 0
        for line in stream_response.iter_lines():
            if line:
                line_count += 1
            else:
                continue

            content = self.parse_line(line)

            if content.strip() == self.stop_sequences:
                content_type = "Finished"
                logger.success("\n[Finished]")
                is_finished = True
            else:
                content_type = "Completions"
                if line_count == 1:
                    content = content.lstrip()
                logger.back(content, end="")

            output = self.message_outputer.output(
                content=content, content_type=content_type
            )
            yield output

        if not is_finished:
            yield self.message_outputer.output(content="", content_type="Finished")
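

# --- Usage sketch (not part of the original module) ---
# A minimal example of how MessageStreamer might be driven end to end,
# assuming the surrounding project supplies the `messagers` and `utils`
# packages imported above. The HF_TOKEN environment variable and the
# example prompt are illustrative placeholders, not values from the source.
if __name__ == "__main__":
    import os

    streamer = MessageStreamer(model="mixtral-8x7b")
    stream_response = streamer.chat_response(
        prompt="What is the capital of France?",
        temperature=0.5,
        top_p=0.95,
        max_new_tokens=256,
        api_key=os.environ.get("HF_TOKEN"),  # optional; anonymous calls may be rate-limited
    )

    # Either collect the whole reply as a single OpenAI-style dict ...
    # result = streamer.chat_return_dict(stream_response)

    # ... or iterate over OpenAI-style streaming chunks as they arrive.
    for chunk in streamer.chat_return_generator(stream_response):
        pass  # e.g. forward each chunk to an SSE endpoint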