#
#  Copyright 2024 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
from abc import ABC
from http import HTTPStatus

import dashscope
import openai
from dashscope import Generation
from ollama import Client
from openai import OpenAI
from zhipuai import ZhipuAI

from rag.nlp import is_english


class Base(ABC):
    """Shared chat interface over any OpenAI-compatible endpoint.

    chat() returns (answer, total_tokens); chat_streamly() yields growing
    partial answers (str) and finally the token count (int).
    """

    def __init__(self, key, model_name, base_url):
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf)
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except openai.APIError as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        total_tokens = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                stream=True,
                **gen_conf)
            for resp in response:
                if not resp.choices[0].delta.content:
                    continue
                ans += resp.choices[0].delta.content
                # Streamed chunks carry no usage data, so approximate:
                # count one token per chunk.
                total_tokens += 1
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except openai.APIError as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield total_tokens


class GptTurbo(Base):
    def __init__(self, key, model_name="gpt-3.5-turbo", base_url="https://api.openai.com/v1"):
        if not base_url:
            base_url = "https://api.openai.com/v1"
        super().__init__(key, model_name, base_url)


class MoonshotChat(Base):
    def __init__(self, key, model_name="moonshot-v1-8k", base_url="https://api.moonshot.cn/v1"):
        if not base_url:
            base_url = "https://api.moonshot.cn/v1"
        super().__init__(key, model_name, base_url)


class XinferenceChat(Base):
    def __init__(self, key=None, model_name="", base_url=""):
        # Xinference does not validate the API key; the OpenAI client only
        # needs a non-empty placeholder.
        key = "xxx"
        super().__init__(key, model_name, base_url)


class DeepSeekChat(Base):
    def __init__(self, key, model_name="deepseek-chat", base_url="https://api.deepseek.com/v1"):
        if not base_url:
            base_url = "https://api.deepseek.com/v1"
        super().__init__(key, model_name, base_url)


class QWenChat(Base):
    def __init__(self, key, model_name=Generation.Models.qwen_turbo, **kwargs):
        dashscope.api_key = key
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        response = Generation.call(
            self.model_name,
            messages=history,
            result_format='message',
            **gen_conf
        )
        ans = ""
        tk_count = 0
        if response.status_code == HTTPStatus.OK:
            ans += response.output.choices[0]['message']['content']
            tk_count += response.usage.total_tokens
            if response.output.choices[0].get("finish_reason", "") == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, tk_count

        return "**ERROR**: " + response.message, tk_count

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        ans = ""
        tk_count = 0
        try:
            response = Generation.call(
                self.model_name,
                messages=history,
                result_format='message',
                stream=True,
                **gen_conf
            )
            for resp in response:
                if resp.status_code == HTTPStatus.OK:
                    # Dashscope streams the accumulated message, not deltas.
                    ans = resp.output.choices[0]['message']['content']
                    tk_count = resp.usage.total_tokens
                    if resp.output.choices[0].get("finish_reason", "") == "length":
                        ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                            [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                    yield ans
                else:
                    yield ans + "\n**ERROR**: " + resp.message if str(resp.message).find(
                        "Access") < 0 else "Out of credit. Please set the API key in **settings > Model providers.**"
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count


class ZhipuChat(Base):
    def __init__(self, key, model_name="glm-3-turbo", **kwargs):
        self.client = ZhipuAI(api_key=key)
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            # ZhipuAI rejects the OpenAI-style penalty parameters.
            if "presence_penalty" in gen_conf:
                del gen_conf["presence_penalty"]
            if "frequency_penalty" in gen_conf:
                del gen_conf["frequency_penalty"]
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                **gen_conf
            )
            ans = response.choices[0].message.content.strip()
            if response.choices[0].finish_reason == "length":
                ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                    [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
            return ans, response.usage.total_tokens
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        if "presence_penalty" in gen_conf:
            del gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf:
            del gen_conf["frequency_penalty"]
        ans = ""
        tk_count = 0
        try:
            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=history,
                stream=True,
                **gen_conf
            )
            for resp in response:
                if not resp.choices[0].delta.content:
                    continue
                ans += resp.choices[0].delta.content
                # Usage is only attached to the final chunk; keep the last
                # value seen instead of resetting to zero each iteration.
                if resp.usage:
                    tk_count = resp.usage.total_tokens
                if resp.choices[0].finish_reason == "length":
                    ans += "...\nFor the content length reason, it stopped, continue?" if is_english(
                        [ans]) else "······\n由于长度的原因,回答被截断了,要继续吗?"
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count


class OllamaChat(Base):
    def __init__(self, key, model_name, **kwargs):
        self.client = Client(host=kwargs["base_url"])
        self.model_name = model_name

    def chat(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        try:
            # Map the OpenAI-style gen_conf keys onto Ollama's options.
            options = {}
            if "temperature" in gen_conf:
                options["temperature"] = gen_conf["temperature"]
            if "max_tokens" in gen_conf:
                options["num_predict"] = gen_conf["max_tokens"]
            if "top_p" in gen_conf:
                options["top_p"] = gen_conf["top_p"]
            if "presence_penalty" in gen_conf:
                options["presence_penalty"] = gen_conf["presence_penalty"]
            if "frequency_penalty" in gen_conf:
                options["frequency_penalty"] = gen_conf["frequency_penalty"]
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                options=options
            )
            ans = response["message"]["content"].strip()
            return ans, response["eval_count"] + response.get("prompt_eval_count", 0)
        except Exception as e:
            return "**ERROR**: " + str(e), 0

    def chat_streamly(self, system, history, gen_conf):
        if system:
            history.insert(0, {"role": "system", "content": system})
        options = {}
        if "temperature" in gen_conf:
            options["temperature"] = gen_conf["temperature"]
        if "max_tokens" in gen_conf:
            options["num_predict"] = gen_conf["max_tokens"]
        if "top_p" in gen_conf:
            options["top_p"] = gen_conf["top_p"]
        if "presence_penalty" in gen_conf:
            options["presence_penalty"] = gen_conf["presence_penalty"]
        if "frequency_penalty" in gen_conf:
            options["frequency_penalty"] = gen_conf["frequency_penalty"]
        ans = ""
        tk_count = 0
        try:
            response = self.client.chat(
                model=self.model_name,
                messages=history,
                stream=True,
                options=options
            )
            for resp in response:
                if resp["done"]:
                    # The final chunk carries the usage counts for the call;
                    # record them and yield once after the loop, matching the
                    # Base contract (partial answers, then one token count).
                    tk_count = resp.get("prompt_eval_count", 0) + resp.get("eval_count", 0)
                    continue
                ans += resp["message"]["content"]
                yield ans
        except Exception as e:
            yield ans + "\n**ERROR**: " + str(e)

        yield tk_count
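

# A minimal usage sketch, not part of the library: it assumes OPENAI_API_KEY
# is set in the environment and that "gpt-3.5-turbo" is reachable with that
# key. Every provider class above honours the same contract, so GptTurbo can
# be swapped for MoonshotChat, DeepSeekChat, etc. without changing the caller.
if __name__ == "__main__":
    import os

    mdl = GptTurbo(os.environ["OPENAI_API_KEY"], "gpt-3.5-turbo")

    # Blocking call: returns the full answer plus the total token count.
    # Note chat() inserts the system prompt into the list it is given, so a
    # fresh history is passed to each call.
    ans, tokens = mdl.chat(
        "You are a terse assistant.",
        [{"role": "user", "content": "Say hello in one word."}],
        {"temperature": 0.1})
    print(ans, tokens)

    # Streaming call: yields growing partial answers, then the token count.
    for chunk in mdl.chat_streamly(
            "You are a terse assistant.",
            [{"role": "user", "content": "Say hello in one word."}],
            {"temperature": 0.1}):
        print(chunk)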