#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
from abc import ABC

import openai
from dashscope import Generation
from openai import OpenAI
from zhipuai import ZhipuAI

from rag.nlp import is_english
from rag.utils import num_tokens_from_string


class Base(ABC):
    """Common interface: chat() returns (answer_text, completion_token_count)."""

    def __init__(self, key, model_name):
        pass

    def chat(self, system, history, gen_conf):
        raise NotImplementedError("Please implement chat method!")


class GptTurbo(Base):
    def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
        # Callers may pass an empty base_url; fall back to the official endpoint.
        if not base_url:
            base_url = "https://api.openai.com/v1"
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf)
            ans = response.choices[0].message.content.strip()
            # Flag truncated answers so the caller can offer to continue.
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.completion_tokens
        except openai.APIError as e:
            return "**ERROR**: " + str(e), 0


class MoonshotChat(GptTurbo):
    # Moonshot exposes an OpenAI-compatible API, so chat() is inherited from GptTurbo.
    def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
        if not base_url:
            base_url = "https://api.moonshot.cn/v1"
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name


class QWenChat(Base):
    def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs):
        import dashscope
        dashscope.api_key = key
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        from http import HTTPStatus
        if system:
            history.insert(0, {"role": "system", "content": system})
        response = Generation.call(
            self.model_name,
            messages=history,
            result_format='message',
            **gen_conf
        )
        ans = ""
        tk_count = 0
        if response.status_code == HTTPStatus.OK:
            ans += response.output.choices[0]['message']['content']
            tk_count += response.usage.output_tokens
            if response.output.choices[0].get("finish_reason", "") == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, tk_count

        return "**ERROR**: " + response.message, tk_count
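
# --- Usage sketch (illustrative only) ----------------------------------------
# All wrappers above share the Base.chat() contract. A minimal sketch of a
# call site follows; the API key, question, and generation parameters are
# placeholder values, not defaults shipped with this module:
#
#     model = GptTurbo(key="sk-...")
#     answer, used_tokens = model.chat(
#         system="You are a helpful assistant.",
#         history=[{"role": "user", "content": "What is RAG?"}],
#         gen_conf={"temperature": 0.7, "max_tokens": 512},
#     )
#
# Note that chat() inserts the system prompt into the caller's history list
# in place, so reuse the same list across calls with care.
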
class ZhipuChat(Base):
    def __init__(self, key, model_name="glm-3-turbo", **kwargs):
        self.client = ZhipuAI(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf
            )
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.completion_tokens
        except Exception as e:
            return "**ERROR**: " + str(e), 0


class LocalLLM(Base):
    class RPCProxy:
        """Forwards attribute calls as pickled (name, args, kwargs) tuples over a socket."""

        def __init__(self, host, port):
            self.host = host
            self.port = int(port)
            self.__conn()

        def __conn(self):
            from multiprocessing.connection import Client
            self._connection = Client(
                (self.host, self.port), authkey=b'infiniflow-token4kevinhu')

        def __getattr__(self, name):
            import pickle

            def do_rpc(*args, **kwargs):
                # Retry up to three times, reconnecting after each failure.
                for _ in range(3):
                    try:
                        self._connection.send(
                            pickle.dumps((name, args, kwargs)))
                        return pickle.loads(self._connection.recv())
                    except Exception:
                        self.__conn()
                raise Exception("RPC connection lost!")

            return do_rpc

    def __init__(self, *args, **kwargs):
        self.client = LocalLLM.RPCProxy("127.0.0.1", 7860)

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            # The token count is estimated locally; the RPC server only returns text.
            ans = self.client.chat(
                history,
                gen_conf
            )
            return ans, num_tokens_from_string(ans)
        except Exception as e:
            return "**ERROR**: " + str(e), 0
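
# --- RPC server sketch (illustrative only) ------------------------------------
# LocalLLM.RPCProxy speaks a simple protocol: each request is a pickled
# (method_name, args, kwargs) tuple and each reply is a single pickled value.
# A minimal loop that would satisfy that protocol is sketched below; the stub
# handler is an assumption for illustration, not the server that ships with
# this project:
#
#     from multiprocessing.connection import Listener
#     import pickle
#
#     with Listener(("127.0.0.1", 7860), authkey=b"infiniflow-token4kevinhu") as listener:
#         with listener.accept() as conn:
#             while True:
#                 name, args, kwargs = pickle.loads(conn.recv())
#                 if name == "chat":
#                     history, gen_conf = args
#                     conn.send(pickle.dumps("(model answer here)"))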