Spaces:
Paused
Paused
| # | |
| # Copyright 2024 The InfiniFlow Authors. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # | |
| import re | |
| from typing import Optional | |
| import threading | |
| import requests | |
| from huggingface_hub import snapshot_download | |
| from openai.lib.azure import AzureOpenAI | |
| from zhipuai import ZhipuAI | |
| import os | |
| from abc import ABC | |
| from ollama import Client | |
| import dashscope | |
| from openai import OpenAI | |
| from FlagEmbedding import FlagModel | |
| import torch | |
| import numpy as np | |
| import asyncio | |
| from api.utils.file_utils import get_home_cache_dir | |
| from rag.utils import num_tokens_from_string, truncate | |
| import google.generativeai as genai | |
class Base(ABC):
    """Common interface for the embedding backends defined in this module.

    Concrete subclasses override :meth:`encode` and :meth:`encode_queries`;
    the implementations here only raise ``NotImplementedError``.
    """

    def __init__(self, key, model_name):
        """Accept a credential and a model name; the base class stores neither."""
        pass

    def encode(self, texts: list, batch_size=32):
        """Embed a list of texts. Subclasses must override this method.

        Raises:
            NotImplementedError: always, in the base class.
        """
        raise NotImplementedError("Please implement encode method!")

    def encode_queries(self, text: str):
        """Embed a single query string. Subclasses must override this method.

        Raises:
            NotImplementedError: always, in the base class.
        """
        # The message previously named ``encode``, pointing implementers at
        # the wrong method.
        raise NotImplementedError("Please implement encode_queries method!")
class DefaultEmbedding(Base):
    """Embedding backend backed by a single ``FlagModel`` shared by all instances."""

    # Shared model object; filled in by the first constructor call that finds it empty.
    _model = None
    # Held while the shared model is being created.
    _model_lock = threading.Lock()

    def __init__(self, key, model_name, **kwargs):
        """Load the shared model on first use and bind it to this instance.

        Tip: if downloading from HuggingFace is a problem, on Linux try
        ``export HF_ENDPOINT=https://hf-mirror.com``.

        The model is first loaded from the home cache directory (``model_name``
        with any leading ``<letters>/`` prefix removed). If that fails for any
        reason, ``BAAI/bge-large-zh-v1.5`` is downloaded into that same
        directory and loaded from there.
        """

        def cache_path():
            # Evaluated lazily at each use site, exactly as the inline expression was.
            return os.path.join(get_home_cache_dir(), re.sub(r"^[a-zA-Z]+/", "", model_name))

        instruction = "为这个句子生成表示以用于检索相关文章:"

        if not DefaultEmbedding._model:
            with DefaultEmbedding._model_lock:
                # Re-check under the lock: another thread may have finished loading.
                if not DefaultEmbedding._model:
                    try:
                        DefaultEmbedding._model = FlagModel(
                            cache_path(),
                            query_instruction_for_retrieval=instruction,
                            use_fp16=torch.cuda.is_available(),
                        )
                    except Exception:
                        downloaded_dir = snapshot_download(
                            repo_id="BAAI/bge-large-zh-v1.5",
                            local_dir=cache_path(),
                            local_dir_use_symlinks=False,
                        )
                        DefaultEmbedding._model = FlagModel(
                            downloaded_dir,
                            query_instruction_for_retrieval=instruction,
                            use_fp16=torch.cuda.is_available(),
                        )
        self._model = DefaultEmbedding._model

    def encode(self, texts: list, batch_size=32):
        """Embed ``texts`` in slices of ``batch_size``.

        Returns:
            ``(embeddings, token_count)`` where ``embeddings`` is a numpy array
            and ``token_count`` is summed over the truncated texts.
        """
        clipped = [truncate(item, 2048) for item in texts]
        token_count = sum(num_tokens_from_string(item) for item in clipped)

        vectors = []
        for start in range(0, len(clipped), batch_size):
            chunk = clipped[start:start + batch_size]
            vectors.extend(self._model.encode(chunk).tolist())
        return np.array(vectors), token_count

    def encode_queries(self, text: str):
        """Embed one query; returns ``(vector_as_list, token_count)``."""
        token_count = num_tokens_from_string(text)
        query_vectors = self._model.encode_queries([text]).tolist()
        return query_vectors[0], token_count
class OpenAIEmbed(Base):
    """Embedding backend that talks to an OpenAI-style embeddings endpoint."""

    def __init__(self, key, model_name="text-embedding-ada-002",
                 base_url="https://api.openai.com/v1"):
        """Create the client.

        Args:
            key: API key handed to the ``OpenAI`` client.
            model_name: embedding model to request.
            base_url: API root; a falsy value falls back to the public endpoint.
        """
        if not base_url:
            base_url = "https://api.openai.com/v1"
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name

    def encode(self, texts: list, batch_size=32):
        """Embed ``texts``, sending at most ``batch_size`` inputs per request.

        ``batch_size`` used to be accepted but ignored, so an arbitrarily long
        list went out in one call. A falsy or non-positive ``batch_size``
        keeps the old single-request behaviour.

        Returns:
            ``(embeddings, total_tokens)``; for empty input no request is made
            and the result is an empty array with a count of 0.
        """
        texts = [truncate(t, 8196) for t in texts]
        step = batch_size if batch_size and batch_size > 0 else max(len(texts), 1)

        vectors = []
        total_tokens = 0
        for start in range(0, len(texts), step):
            res = self.client.embeddings.create(input=texts[start:start + step],
                                                model=self.model_name)
            vectors.extend(d.embedding for d in res.data)
            total_tokens += res.usage.total_tokens
        return np.array(vectors), total_tokens

    def encode_queries(self, text):
        """Embed one query; returns ``(vector, total_tokens)``."""
        res = self.client.embeddings.create(input=[truncate(text, 8196)],
                                            model=self.model_name)
        return np.array(res.data[0].embedding), res.usage.total_tokens
class LocalAIEmbed(Base):
    """Embedding backend for a locally hosted OpenAI-compatible server."""

    def __init__(self, key, model_name, base_url):
        """Create the client.

        Args:
            key: unused; the client is always created with the key ``"empty"``.
            model_name: model identifier; anything from the first ``"___"``
                onward is dropped.
            base_url: server root. ``/v1`` is appended unless the URL already
                ends with it.

        Raises:
            ValueError: if ``base_url`` is empty.
        """
        if not base_url:
            raise ValueError("Local embedding model url cannot be None")
        # Build the URL with plain string operations: os.path.join uses the
        # platform path separator and turned ".../v1/" into ".../v1/v1".
        base_url = base_url.rstrip("/")
        if base_url.split("/")[-1] != "v1":
            base_url = base_url + "/v1"
        self.client = OpenAI(api_key="empty", base_url=base_url)
        self.model_name = model_name.split("___")[0]

    def encode(self, texts: list, batch_size=32):
        """Embed ``texts`` in one request.

        Returns:
            ``(embeddings, 1024)``. The token count is a fixed placeholder:
            usage is not read from the response.
        """
        res = self.client.embeddings.create(input=texts, model=self.model_name)
        return (
            np.array([d.embedding for d in res.data]),
            1024,
        )  # local servers such as LM Studio do not report token counts

    def encode_queries(self, text):
        """Embed one query via :meth:`encode`; returns ``(vector, token_count)``."""
        embds, cnt = self.encode([text])
        return np.array(embds[0]), cnt
class AzureEmbed(OpenAIEmbed):
    """OpenAI-style embedding backend that uses an ``AzureOpenAI`` client.

    ``encode`` and ``encode_queries`` are inherited unchanged.
    """

    def __init__(self, key, model_name, **kwargs):
        """Create the Azure client; ``kwargs`` must contain ``base_url``."""
        endpoint = kwargs["base_url"]
        azure_client = AzureOpenAI(
            api_key=key,
            azure_endpoint=endpoint,
            api_version="2024-02-01",
        )
        self.client = azure_client
        self.model_name = model_name
class BaiChuanEmbed(OpenAIEmbed):
    """OpenAI-style embedding backend pointed at the Baichuan endpoint by default."""

    def __init__(self, key,
                 model_name='Baichuan-Text-Embedding',
                 base_url='https://api.baichuan-ai.com/v1'):
        """Delegate to ``OpenAIEmbed``, substituting the Baichuan URL for a falsy ``base_url``."""
        effective_url = base_url or "https://api.baichuan-ai.com/v1"
        super().__init__(key, model_name, effective_url)
class QWenEmbed(Base):
    """Embedding backend that calls ``dashscope.TextEmbedding``."""

    def __init__(self, key, model_name="text_embedding_v2", **kwargs):
        """Set the module-level ``dashscope.api_key`` and remember the model name."""
        dashscope.api_key = key
        self.model_name = model_name

    def encode(self, texts: list, batch_size=10):
        """Embed ``texts`` as documents, ``batch_size`` inputs per call.

        Returns:
            ``(embeddings, token_count)``.

        Raises:
            Exception: any failure is re-raised as a generic "Account abnormal"
                error, with the original exception chained as its cause.
        """
        try:
            res = []
            token_count = 0
            texts = [truncate(t, 2048) for t in texts]
            for i in range(0, len(texts), batch_size):
                resp = dashscope.TextEmbedding.call(
                    model=self.model_name,
                    input=texts[i:i + batch_size],
                    text_type="document"
                )
                # Place each embedding at the position given by its
                # ``text_index`` rather than relying on response order.
                embds = [[] for _ in range(len(resp["output"]["embeddings"]))]
                for e in resp["output"]["embeddings"]:
                    embds[e["text_index"]] = e["embedding"]
                res.extend(embds)
                token_count += resp["usage"]["total_tokens"]
            return np.array(res), token_count
        except Exception as e:
            # Chain the cause so the real failure is not lost.
            raise Exception(
                "Account abnormal. Please ensure it's on good standing to use QWen's " + self.model_name
            ) from e

    def encode_queries(self, text):
        """Embed one query (first 2048 characters); returns ``(vector, token_count)``.

        Raises:
            Exception: any failure is re-raised as a generic "Account abnormal"
                error, with the original exception chained as its cause.
        """
        try:
            resp = dashscope.TextEmbedding.call(
                model=self.model_name,
                input=text[:2048],
                text_type="query"
            )
            return np.array(resp["output"]["embeddings"][0]
                            ["embedding"]), resp["usage"]["total_tokens"]
        except Exception as e:
            raise Exception(
                "Account abnormal. Please ensure it's on good standing to use QWen's " + self.model_name
            ) from e
class ZhipuEmbed(Base):
    """Embedding backend that uses the ``ZhipuAI`` client."""

    def __init__(self, key, model_name="embedding-2", **kwargs):
        """Create the client from ``key`` and remember ``model_name``."""
        self.client = ZhipuAI(api_key=key)
        self.model_name = model_name

    def encode(self, texts: list, batch_size=32):
        """Embed ``texts`` one request per text.

        Returns:
            ``(embeddings, token_count)`` with the token count summed over
            all responses. ``batch_size`` is not used.
        """
        vectors, total_tokens = [], 0
        for piece in texts:
            reply = self.client.embeddings.create(input=piece, model=self.model_name)
            vectors.append(reply.data[0].embedding)
            total_tokens += reply.usage.total_tokens
        return np.array(vectors), total_tokens

    def encode_queries(self, text):
        """Embed one query; returns ``(vector, total_tokens)``."""
        reply = self.client.embeddings.create(input=text, model=self.model_name)
        vector = np.array(reply.data[0].embedding)
        return vector, reply.usage.total_tokens
class OllamaEmbed(Base):
    """Embedding backend that uses the Ollama ``Client``."""

    def __init__(self, key, model_name, **kwargs):
        """Create the client; ``kwargs`` must contain ``base_url`` (used as the host)."""
        host = kwargs["base_url"]
        self.client = Client(host=host)
        self.model_name = model_name

    def encode(self, texts: list, batch_size=32):
        """Embed ``texts`` one request per text.

        Returns:
            ``(embeddings, token_count)``. The count is a flat 128 per text,
            not a value read from the response. ``batch_size`` is not used.
        """
        vectors = [
            self.client.embeddings(prompt=piece, model=self.model_name)["embedding"]
            for piece in texts
        ]
        return np.array(vectors), 128 * len(vectors)

    def encode_queries(self, text):
        """Embed one query; returns ``(vector, 128)``."""
        reply = self.client.embeddings(prompt=text, model=self.model_name)
        return np.array(reply["embedding"]), 128
class FastEmbed(Base):
    """Embedding backend built on ``fastembed.TextEmbedding``."""

    # Most recently loaded model and the name it was loaded under, shared by
    # all instances so the same model is not reloaded per instance.
    _model = None
    _model_name = None

    def __init__(
        self,
        key: Optional[str] = None,
        model_name: str = "BAAI/bge-small-en-v1.5",
        cache_dir: Optional[str] = None,
        threads: Optional[int] = None,
        **kwargs,
    ):
        """Load (or reuse) the ``TextEmbedding`` model for ``model_name``.

        The original code tested ``FastEmbed._model`` but assigned to the
        instance, so the class-level cache stayed ``None`` and every instance
        loaded a fresh model. The model is now stored on the class and reused
        when the same ``model_name`` is requested again.

        NOTE: on reuse, ``cache_dir``, ``threads`` and ``kwargs`` of the later
        call are not applied.
        """
        from fastembed import TextEmbedding

        if not FastEmbed._model or FastEmbed._model_name != model_name:
            FastEmbed._model = TextEmbedding(model_name, cache_dir, threads, **kwargs)
            FastEmbed._model_name = model_name
        # Bind on the instance so a later load of a different model does not
        # swap the model under an existing instance.
        self._model = FastEmbed._model

    def encode(self, texts: list, batch_size=32):
        """Embed ``texts``; returns ``(embeddings, total_tokens)``.

        The token total comes from the model's own tokenizer.
        """
        encodings = self._model.model.tokenizer.encode_batch(texts)
        total_tokens = sum(len(e) for e in encodings)

        embeddings = [e.tolist() for e in self._model.embed(texts, batch_size)]

        return np.array(embeddings), total_tokens

    def encode_queries(self, text: str):
        """Embed one query; returns ``(vector, token_count)``.

        The token count comes from the model's own tokenizer.
        """
        encoding = self._model.model.tokenizer.encode(text)
        embedding = next(self._model.query_embed(text)).tolist()

        return np.array(embedding), len(encoding.ids)
class XinferenceEmbed(Base):
    """Embedding backend for an OpenAI-compatible server reached via ``base_url``."""

    def __init__(self, key, model_name="", base_url=""):
        """Create the client; ``key`` is not used (the client gets the key ``"xxx"``)."""
        self.client = OpenAI(api_key="xxx", base_url=base_url)
        self.model_name = model_name

    def encode(self, texts: list, batch_size=32):
        """Embed ``texts`` in one request; returns ``(embeddings, total_tokens)``.

        ``batch_size`` is not used.
        """
        reply = self.client.embeddings.create(input=texts, model=self.model_name)
        vectors = np.array([item.embedding for item in reply.data])
        return vectors, reply.usage.total_tokens

    def encode_queries(self, text):
        """Embed one query; returns ``(vector, total_tokens)``."""
        reply = self.client.embeddings.create(input=[text], model=self.model_name)
        first = reply.data[0]
        return np.array(first.embedding), reply.usage.total_tokens
class YoudaoEmbed(Base):
    """Embedding backend built on ``BCEmbedding.EmbeddingModel``."""

    # Model object shared by all instances; created by the first constructor call.
    _client = None

    def __init__(self, key=None, model_name="maidalun1020/bce-embedding-base_v1", **kwargs):
        """Create the shared model if it does not exist yet.

        The model is first loaded from ``bce-embedding-base_v1`` under the home
        cache directory. If that raises, it is loaded by name instead, with
        ``maidalun1020`` replaced by ``InfiniFlow`` in ``model_name``.
        """
        from BCEmbedding import EmbeddingModel as qanthing

        if YoudaoEmbed._client:
            return

        try:
            print("LOADING BCE...")
            local_path = os.path.join(get_home_cache_dir(), "bce-embedding-base_v1")
            YoudaoEmbed._client = qanthing(model_name_or_path=local_path)
        except Exception:
            fallback_name = model_name.replace("maidalun1020", "InfiniFlow")
            YoudaoEmbed._client = qanthing(model_name_or_path=fallback_name)

    def encode(self, texts: list, batch_size=10):
        """Embed ``texts`` in slices of ``batch_size``; returns ``(embeddings, token_count)``."""
        token_count = sum(num_tokens_from_string(item) for item in texts)

        vectors = []
        for start in range(0, len(texts), batch_size):
            vectors.extend(YoudaoEmbed._client.encode(texts[start:start + batch_size]))
        return np.array(vectors), token_count

    def encode_queries(self, text):
        """Embed one query; returns ``(vector, token_count)``."""
        first = YoudaoEmbed._client.encode([text])[0]
        return np.array(first), num_tokens_from_string(text)
class JinaEmbed(Base):
    """Embedding backend that POSTs to the Jina embeddings HTTP endpoint."""

    def __init__(self, key, model_name="jina-embeddings-v2-base-zh",
                 base_url="https://api.jina.ai/v1/embeddings"):
        """Store the endpoint, auth headers and model name.

        Args:
            key: API key, sent as a bearer token.
            model_name: embedding model to request.
            base_url: full endpoint URL. This argument used to be ignored in
                favour of a hard-coded URL; it is now honoured, with a falsy
                value falling back to the public Jina endpoint.
        """
        self.base_url = base_url or "https://api.jina.ai/v1/embeddings"
        self.headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {key}"
        }
        self.model_name = model_name

    def encode(self, texts: list, batch_size=None):
        """Embed ``texts`` in one request; returns ``(embeddings, total_tokens)``.

        ``batch_size`` is not used.
        """
        texts = [truncate(t, 8196) for t in texts]
        data = {
            "model": self.model_name,
            "input": texts,
            'encoding_type': 'float'
        }
        res = requests.post(self.base_url, headers=self.headers, json=data).json()
        return np.array([d["embedding"] for d in res["data"]]), res["usage"]["total_tokens"]

    def encode_queries(self, text):
        """Embed one query via :meth:`encode`; returns ``(vector, total_tokens)``."""
        embds, cnt = self.encode([text])
        return np.array(embds[0]), cnt
class InfinityEmbed(Base):
    """Embedding backend built on ``infinity_emb``'s ``AsyncEngineArray``."""

    _model = None

    def __init__(
        self,
        model_names: list[str] = ("BAAI/bge-small-en-v1.5",),
        engine_kwargs: Optional[dict] = None,
        key=None,
    ):
        """Build one engine per entry in ``model_names``.

        Args:
            model_names: models to serve; the first is the default model.
            engine_kwargs: extra keyword arguments passed to every
                ``EngineArgs``. Defaults to none. The default was a shared
                mutable ``{}``; ``None`` avoids state leaking between calls.
            key: unused.
        """
        from infinity_emb import EngineArgs
        from infinity_emb.engine import AsyncEngineArray

        if engine_kwargs is None:
            engine_kwargs = {}

        self._default_model = model_names[0]
        self.engine_array = AsyncEngineArray.from_args(
            [EngineArgs(model_name_or_path=model_name, **engine_kwargs) for model_name in model_names]
        )

    async def _embed(self, sentences: list[str], model_name: str = ""):
        """Embed ``sentences`` with the named engine (default model if empty).

        The engine is started for the call and stopped afterwards, unless it
        was already running on entry.
        """
        if not model_name:
            model_name = self._default_model
        engine = self.engine_array[model_name]
        was_already_running = engine.is_running
        if not was_already_running:
            await engine.astart()
        embeddings, usage = await engine.embed(sentences=sentences)
        if not was_already_running:
            await engine.astop()
        return embeddings, usage

    def encode(self, texts: list[str], model_name: str = "") -> tuple[np.ndarray, int]:
        """Embed ``texts`` synchronously; returns ``(embeddings, usage)``.

        Runs :meth:`_embed` with ``asyncio.run``.
        """
        embeddings, usage = asyncio.run(self._embed(texts, model_name))
        return np.array(embeddings), usage

    def encode_queries(self, text: str) -> tuple[np.ndarray, int]:
        """Embed one query via :meth:`encode`.

        NOTE(review): this returns the array for the one-element list as-is,
        whereas most other backends in this file return the single vector.
        Confirm what callers expect before changing it.
        """
        return self.encode([text])
class MistralEmbed(Base):
    """Embedding backend that uses ``MistralClient``."""

    def __init__(self, key, model_name="mistral-embed",
                 base_url=None):
        """Create the client from ``key``; ``base_url`` is accepted but not used."""
        from mistralai.client import MistralClient

        self.client = MistralClient(api_key=key)
        self.model_name = model_name

    def encode(self, texts: list, batch_size=32):
        """Embed ``texts`` in one request; returns ``(embeddings, total_tokens)``.

        ``batch_size`` is not used.
        """
        clipped = [truncate(item, 8196) for item in texts]
        reply = self.client.embeddings(input=clipped, model=self.model_name)
        vectors = np.array([item.embedding for item in reply.data])
        return vectors, reply.usage.total_tokens

    def encode_queries(self, text):
        """Embed one query; returns ``(vector, total_tokens)``."""
        clipped = truncate(text, 8196)
        reply = self.client.embeddings(input=[clipped], model=self.model_name)
        return np.array(reply.data[0].embedding), reply.usage.total_tokens
class BedrockEmbed(Base):
    """Embedding backend that invokes models through the ``bedrock-runtime`` client."""

    def __init__(self, key, model_name,
                 **kwargs):
        """Parse credentials from ``key`` and create the boto3 client.

        Args:
            key: string form of a dict literal that may contain
                ``bedrock_ak``, ``bedrock_sk`` and ``bedrock_region``; missing
                entries default to ``''``.
            model_name: Bedrock model id; the part before the first ``.``
                selects the request format (``amazon`` or ``cohere``).
        """
        import ast

        import boto3

        # SECURITY: this used to call eval(key) three times. literal_eval
        # parses the same dict literal without executing arbitrary code.
        credentials = ast.literal_eval(key)
        self.bedrock_ak = credentials.get('bedrock_ak', '')
        self.bedrock_sk = credentials.get('bedrock_sk', '')
        self.bedrock_region = credentials.get('bedrock_region', '')
        self.model_name = model_name
        self.client = boto3.client(service_name='bedrock-runtime', region_name=self.bedrock_region,
                                   aws_access_key_id=self.bedrock_ak, aws_secret_access_key=self.bedrock_sk)

    def _embed_one(self, text, input_type):
        """Invoke the model for one already-truncated text and return its vector.

        Raises:
            ValueError: if the model id prefix is neither ``amazon`` nor
                ``cohere`` (this used to fail later on an unbound ``body``).
        """
        # Local import: the module's visible import list has no ``import json``,
        # so the original references to ``json`` failed at call time.
        import json

        provider = self.model_name.split('.')[0]
        if provider == 'amazon':
            body = {"inputText": text}
        elif provider == 'cohere':
            body = {"texts": [text], "input_type": input_type}
        else:
            raise ValueError(f"Unsupported Bedrock embedding model: {self.model_name}")

        response = self.client.invoke_model(modelId=self.model_name, body=json.dumps(body))
        model_response = json.loads(response["body"].read())
        if "embedding" in model_response:
            return model_response["embedding"]
        # Cohere models answer with a list under "embeddings", one per input text.
        return model_response["embeddings"][0]

    def encode(self, texts: list, batch_size=32):
        """Embed ``texts`` one request per text; returns ``(embeddings, token_count)``.

        ``batch_size`` is not used. Tokens are counted on the truncated texts.
        """
        texts = [truncate(t, 8196) for t in texts]
        embeddings = []
        token_count = 0
        for text in texts:
            embeddings.append(self._embed_one(text, 'search_document'))
            token_count += num_tokens_from_string(text)

        return np.array(embeddings), token_count

    def encode_queries(self, text):
        """Embed one query; returns ``(embeddings, token_count)``.

        As before, the array wraps the single vector in a one-element list,
        and tokens are counted on the untruncated text.
        """
        token_count = num_tokens_from_string(text)
        embeddings = [self._embed_one(truncate(text, 8196), 'search_query')]

        return np.array(embeddings), token_count
class GeminiEmbed(Base):
    """Embedding backend that uses ``google.generativeai.embed_content``."""

    def __init__(self, key, model_name='models/text-embedding-004',
                 **kwargs):
        """Configure the library with ``key`` and normalise the model name.

        The ``models/`` prefix is added only when missing. It used to be added
        unconditionally, so the default argument became
        ``models/models/text-embedding-004``.
        """
        genai.configure(api_key=key)
        if model_name.startswith('models/'):
            self.model_name = model_name
        else:
            self.model_name = 'models/' + model_name

    def encode(self, texts: list, batch_size=32):
        """Embed ``texts`` in one call; returns ``(embeddings, token_count)``.

        ``batch_size`` is not used. Tokens are counted on the truncated texts.
        """
        texts = [truncate(t, 2048) for t in texts]
        token_count = sum(num_tokens_from_string(text) for text in texts)
        result = genai.embed_content(
            model=self.model_name,
            content=texts,
            task_type="retrieval_document",
            title="Embedding of list of strings")
        return np.array(result['embedding']), token_count

    def encode_queries(self, text):
        """Embed one query; returns ``(vector, token_count)``.

        Tokens are counted on the untruncated text.

        NOTE(review): the query is embedded with
        ``task_type="retrieval_document"``; confirm whether a query task type
        was intended before changing it, since it alters the vectors.
        """
        result = genai.embed_content(
            model=self.model_name,
            content=truncate(text, 2048),
            task_type="retrieval_document",
            title="Embedding of single string")
        token_count = num_tokens_from_string(text)
        return np.array(result['embedding']), token_count
class NvidiaEmbed(Base):
    """Embedding backend that POSTs to an NVIDIA embeddings HTTP endpoint."""

    def __init__(
        self, key, model_name, base_url="https://integrate.api.nvidia.com/v1/embeddings"
    ):
        """Store endpoint, headers and model name.

        A falsy ``base_url`` falls back to the default endpoint. Two model
        names override the endpoint, and ``nvidia/embed-qa-4`` is also sent
        under the name ``NV-Embed-QA``.
        """
        self.api_key = key
        self.base_url = base_url or "https://integrate.api.nvidia.com/v1/embeddings"
        self.headers = {
            "accept": "application/json",
            "Content-Type": "application/json",
            "authorization": f"Bearer {self.api_key}",
        }
        self.model_name = model_name

        # Model-specific endpoints take precedence over ``base_url``.
        if model_name == "nvidia/embed-qa-4":
            self.base_url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/embeddings"
            self.model_name = "NV-Embed-QA"
        elif model_name == "snowflake/arctic-embed-l":
            self.base_url = "https://ai.api.nvidia.com/v1/retrieval/snowflake/arctic-embed-l/embeddings"

    def encode(self, texts: list, batch_size=None):
        """Embed ``texts`` in one request; returns ``(embeddings, total_tokens)``.

        ``batch_size`` is not used. Every request is sent with
        ``input_type`` set to ``"query"``.
        """
        request_body = {
            "input": texts,
            "input_type": "query",
            "model": self.model_name,
            "encoding_format": "float",
            "truncate": "END",
        }
        reply = requests.post(self.base_url, headers=self.headers, json=request_body).json()
        vectors = np.array([item["embedding"] for item in reply["data"]])
        return vectors, reply["usage"]["total_tokens"]

    def encode_queries(self, text):
        """Embed one query via :meth:`encode`; returns ``(vector, total_tokens)``."""
        vectors, token_count = self.encode([text])
        return np.array(vectors[0]), token_count
class LmStudioEmbed(LocalAIEmbed):
    """``LocalAIEmbed`` variant for LM Studio; only the constructor differs."""

    def __init__(self, key, model_name, base_url):
        """Create the client.

        Args:
            key: unused; the client is created with the key ``"lm-studio"``.
            model_name: model identifier, used as given.
            base_url: server root. ``/v1`` is appended unless the URL already
                ends with it.

        Raises:
            ValueError: if ``base_url`` is empty.
        """
        if not base_url:
            raise ValueError("Local llm url cannot be None")
        # Build the URL with plain string operations: os.path.join uses the
        # platform path separator and turned ".../v1/" into ".../v1/v1".
        base_url = base_url.rstrip("/")
        if base_url.split("/")[-1] != "v1":
            base_url = base_url + "/v1"
        self.client = OpenAI(api_key="lm-studio", base_url=base_url)
        self.model_name = model_name
class OpenAI_APIEmbed(OpenAIEmbed):
    """``OpenAIEmbed`` variant for an arbitrary OpenAI-compatible server."""

    def __init__(self, key, model_name, base_url):
        """Create the client.

        Args:
            key: API key handed to the ``OpenAI`` client.
            model_name: model identifier; anything from the first ``"___"``
                onward is dropped.
            base_url: server root. ``/v1`` is appended unless the URL already
                ends with it.

        Raises:
            ValueError: if ``base_url`` is empty.
        """
        if not base_url:
            raise ValueError("url cannot be None")
        # Build the URL with plain string operations: os.path.join uses the
        # platform path separator and turned ".../v1/" into ".../v1/v1".
        base_url = base_url.rstrip("/")
        if base_url.split("/")[-1] != "v1":
            base_url = base_url + "/v1"
        self.client = OpenAI(api_key=key, base_url=base_url)
        self.model_name = model_name.split("___")[0]
class CoHereEmbed(Base):
    """Embedding backend that uses the ``cohere`` client."""

    def __init__(self, key, model_name, base_url=None):
        """Create the client from ``key``; ``base_url`` is accepted but not used."""
        from cohere import Client

        self.client = Client(api_key=key)
        self.model_name = model_name

    def encode(self, texts: list, batch_size=32):
        """Embed document texts in one request; returns ``(embeddings, input_tokens)``.

        Documents are sent with ``input_type="search_document"``. They were
        previously sent as ``"search_query"``, the same type as queries,
        unlike the Cohere path of the Bedrock backend in this file.
        ``batch_size`` is not used.
        """
        res = self.client.embed(
            texts=texts,
            model=self.model_name,
            input_type="search_document",
            embedding_types=["float"],
        )
        return np.array([d for d in res.embeddings.float]), int(
            res.meta.billed_units.input_tokens
        )

    def encode_queries(self, text):
        """Embed one query; returns ``(embeddings, input_tokens)``.

        The array holds the embeddings for the one-element input list as
        returned by the API.
        """
        res = self.client.embed(
            texts=[text],
            model=self.model_name,
            input_type="search_query",
            embedding_types=["float"],
        )
        return np.array([d for d in res.embeddings.float]), int(
            res.meta.billed_units.input_tokens
        )