Spaces:
Runtime error
Runtime error
"""OpenAI embeddings file.""" | |
from enum import Enum | |
from typing import List, Optional | |
import openai | |
from tenacity import retry, stop_after_attempt, wait_random_exponential | |
from gpt_index.embeddings.base import BaseEmbedding | |
class OpenAIEmbeddingMode(str, Enum):
    """OpenAI embedding mode.

    The mode selects which family of engines a (mode, model) pair resolves
    to in ``_QUERY_MODE_MODEL_DICT`` / ``_TEXT_MODE_MODEL_DICT``.

    Members subclass ``str``: each one compares equal to its plain string
    value and can be constructed from it, e.g.
    ``OpenAIEmbeddingMode("similarity")``.
    """

    # Maps to the "text-similarity-*" engines for the legacy models.
    SIMILARITY_MODE = "similarity"
    # Maps to the "text-search-*-query" / "text-search-*-doc" engines for the
    # legacy models.
    TEXT_SEARCH_MODE = "text_search"
class OpenAIEmbeddingModelType(str, Enum):
    """OpenAI embedding model type.

    This is the short, user-facing model name. It is combined with an
    ``OpenAIEmbeddingMode`` to look up the concrete engine name.

    Members subclass ``str``, so a member hashes and compares like its plain
    string value; the lookup tables in this module rely on that by using raw
    strings such as ``"davinci"`` in their keys.
    """

    DAVINCI = "davinci"
    CURIE = "curie"
    BABBAGE = "babbage"
    ADA = "ada"
    TEXT_EMBED_ADA_002 = "text-embedding-ada-002"
class OpenAIEmbeddingModeModel(str, Enum):
    """OpenAI embedding mode model.

    Each member's value is a concrete engine name; these are the values the
    (mode, model) lookup tables in this module resolve to. The legacy
    families (davinci, curie, babbage, ada) each have one similarity engine
    plus separate query and document engines for text search.
    """

    # davinci
    TEXT_SIMILARITY_DAVINCI = "text-similarity-davinci-001"
    TEXT_SEARCH_DAVINCI_QUERY = "text-search-davinci-query-001"
    TEXT_SEARCH_DAVINCI_DOC = "text-search-davinci-doc-001"

    # curie
    TEXT_SIMILARITY_CURIE = "text-similarity-curie-001"
    TEXT_SEARCH_CURIE_QUERY = "text-search-curie-query-001"
    TEXT_SEARCH_CURIE_DOC = "text-search-curie-doc-001"

    # babbage
    TEXT_SIMILARITY_BABBAGE = "text-similarity-babbage-001"
    TEXT_SEARCH_BABBAGE_QUERY = "text-search-babbage-query-001"
    TEXT_SEARCH_BABBAGE_DOC = "text-search-babbage-doc-001"

    # ada
    TEXT_SIMILARITY_ADA = "text-similarity-ada-001"
    TEXT_SEARCH_ADA_QUERY = "text-search-ada-query-001"
    TEXT_SEARCH_ADA_DOC = "text-search-ada-doc-001"

    # text-embedding-ada-002 (a single engine name, no per-mode variants)
    TEXT_EMBED_ADA_002 = "text-embedding-ada-002"
# Convenient shorthand aliases for the enum classes, used to keep the lookup
# tables below readable.
OAEM = OpenAIEmbeddingMode
OAEMT = OpenAIEmbeddingModelType
OAEMM = OpenAIEmbeddingModeModel

# NOTE(review): not referenced in the visible part of this module; presumably
# the maximum number of tokens accepted per embedding input -- confirm against
# the callers before relying on it.
EMBED_MAX_TOKEN_LIMIT = 2048
# Engine used to embed a *query*, keyed by (mode, model name).
# It differs from _TEXT_MODE_MODEL_DICT only in text-search mode for the
# legacy models, where the "*-query-001" engines are selected.
# Keys use plain model-name strings; lookups with OpenAIEmbeddingModelType
# members still match because those members are ``str`` subclasses.
_QUERY_MODE_MODEL_DICT = {
    (OAEM.SIMILARITY_MODE, "davinci"): OAEMM.TEXT_SIMILARITY_DAVINCI,
    (OAEM.SIMILARITY_MODE, "curie"): OAEMM.TEXT_SIMILARITY_CURIE,
    (OAEM.SIMILARITY_MODE, "babbage"): OAEMM.TEXT_SIMILARITY_BABBAGE,
    (OAEM.SIMILARITY_MODE, "ada"): OAEMM.TEXT_SIMILARITY_ADA,
    (OAEM.SIMILARITY_MODE, "text-embedding-ada-002"): OAEMM.TEXT_EMBED_ADA_002,
    (OAEM.TEXT_SEARCH_MODE, "davinci"): OAEMM.TEXT_SEARCH_DAVINCI_QUERY,
    (OAEM.TEXT_SEARCH_MODE, "curie"): OAEMM.TEXT_SEARCH_CURIE_QUERY,
    (OAEM.TEXT_SEARCH_MODE, "babbage"): OAEMM.TEXT_SEARCH_BABBAGE_QUERY,
    (OAEM.TEXT_SEARCH_MODE, "ada"): OAEMM.TEXT_SEARCH_ADA_QUERY,
    (OAEM.TEXT_SEARCH_MODE, "text-embedding-ada-002"): OAEMM.TEXT_EMBED_ADA_002,
}
# Engine used to embed a *document/text*, keyed by (mode, model name).
# It differs from _QUERY_MODE_MODEL_DICT only in text-search mode for the
# legacy models, where the "*-doc-001" engines are selected.
# Keys use plain model-name strings; lookups with OpenAIEmbeddingModelType
# members still match because those members are ``str`` subclasses.
_TEXT_MODE_MODEL_DICT = {
    (OAEM.SIMILARITY_MODE, "davinci"): OAEMM.TEXT_SIMILARITY_DAVINCI,
    (OAEM.SIMILARITY_MODE, "curie"): OAEMM.TEXT_SIMILARITY_CURIE,
    (OAEM.SIMILARITY_MODE, "babbage"): OAEMM.TEXT_SIMILARITY_BABBAGE,
    (OAEM.SIMILARITY_MODE, "ada"): OAEMM.TEXT_SIMILARITY_ADA,
    (OAEM.SIMILARITY_MODE, "text-embedding-ada-002"): OAEMM.TEXT_EMBED_ADA_002,
    (OAEM.TEXT_SEARCH_MODE, "davinci"): OAEMM.TEXT_SEARCH_DAVINCI_DOC,
    (OAEM.TEXT_SEARCH_MODE, "curie"): OAEMM.TEXT_SEARCH_CURIE_DOC,
    (OAEM.TEXT_SEARCH_MODE, "babbage"): OAEMM.TEXT_SEARCH_BABBAGE_DOC,
    (OAEM.TEXT_SEARCH_MODE, "ada"): OAEMM.TEXT_SEARCH_ADA_DOC,
    (OAEM.TEXT_SEARCH_MODE, "text-embedding-ada-002"): OAEMM.TEXT_EMBED_ADA_002,
}
def get_embedding(
    text: str,
    engine: Optional[str] = None,
) -> List[float]:
    """Get embedding.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.

    Args:
        text (str): Text to embed. Newlines are replaced with spaces first.
        engine (Optional[str]): Engine (or deployment) name forwarded to
            ``openai.Embedding.create``.

    Returns:
        List[float]: The embedding of the single input text.
    """
    # replace newlines, which can negatively affect performance.
    text = text.replace("\n", " ")
    response = openai.Embedding.create(input=[text], engine=engine)
    # Exactly one input was sent, so the embedding is the first data item.
    return response["data"][0]["embedding"]
async def aget_embedding(text: str, engine: Optional[str] = None) -> List[float]:
    """Asynchronously get embedding.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.

    Args:
        text (str): Text to embed. Newlines are replaced with spaces first.
        engine (Optional[str]): Engine (or deployment) name forwarded to
            ``openai.Embedding.acreate``.

    Returns:
        List[float]: The embedding of the single input text.
    """
    # replace newlines, which can negatively affect performance.
    text = text.replace("\n", " ")
    response = await openai.Embedding.acreate(input=[text], engine=engine)
    # Exactly one input was sent, so the embedding is the first data item.
    return response["data"][0]["embedding"]
def get_embeddings(
    list_of_text: List[str],
    engine: Optional[str] = None,
) -> List[List[float]]:
    """Get embeddings.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.

    Args:
        list_of_text (List[str]): Texts to embed in one request (at most
            2048). Newlines in each text are replaced with spaces first.
        engine (Optional[str]): Engine (or deployment) name forwarded to
            ``openai.Embedding.create``.

    Returns:
        List[List[float]]: One embedding per input text, in input order.
        An empty input yields an empty list without calling the API.

    Raises:
        AssertionError: If more than 2048 texts are passed.
    """
    # Raised explicitly instead of via `assert` so the check is not stripped
    # under `python -O`; the exception type and message are kept unchanged
    # for existing callers.
    if len(list_of_text) > 2048:
        raise AssertionError("The batch size should not be larger than 2048.")

    # Nothing to embed: skip the network round-trip entirely.
    if not list_of_text:
        return []

    # replace newlines, which can negatively affect performance.
    list_of_text = [text.replace("\n", " ") for text in list_of_text]

    data = openai.Embedding.create(input=list_of_text, engine=engine).data
    data = sorted(data, key=lambda x: x["index"])  # maintain the same order as input.
    return [d["embedding"] for d in data]
async def aget_embeddings(
    list_of_text: List[str], engine: Optional[str] = None
) -> List[List[float]]:
    """Asynchronously get embeddings.

    NOTE: Copied from OpenAI's embedding utils:
    https://github.com/openai/openai-python/blob/main/openai/embeddings_utils.py

    Copied here to avoid importing unnecessary dependencies
    like matplotlib, plotly, scipy, sklearn.

    Args:
        list_of_text (List[str]): Texts to embed in one request (at most
            2048). Newlines in each text are replaced with spaces first.
        engine (Optional[str]): Engine (or deployment) name forwarded to
            ``openai.Embedding.acreate``.

    Returns:
        List[List[float]]: One embedding per input text, in input order.
        An empty input yields an empty list without calling the API.

    Raises:
        AssertionError: If more than 2048 texts are passed.
    """
    # Raised explicitly instead of via `assert` so the check is not stripped
    # under `python -O`; the exception type and message are kept unchanged
    # for existing callers.
    if len(list_of_text) > 2048:
        raise AssertionError("The batch size should not be larger than 2048.")

    # Nothing to embed: skip the network round-trip entirely.
    if not list_of_text:
        return []

    # replace newlines, which can negatively affect performance.
    list_of_text = [text.replace("\n", " ") for text in list_of_text]

    response = await openai.Embedding.acreate(input=list_of_text, engine=engine)
    data = sorted(response.data, key=lambda x: x["index"])  # maintain the same order as input.
    return [d["embedding"] for d in data]
class OpenAIEmbedding(BaseEmbedding):
    """OpenAI class for embeddings.

    Args:
        mode (str): Mode for embedding.
            Defaults to OpenAIEmbeddingMode.TEXT_SEARCH_MODE.
            Options are:

            - OpenAIEmbeddingMode.SIMILARITY_MODE
            - OpenAIEmbeddingMode.TEXT_SEARCH_MODE

        model (str): Model for embedding.
            Defaults to OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002.
            Options are:

            - OpenAIEmbeddingModelType.DAVINCI
            - OpenAIEmbeddingModelType.CURIE
            - OpenAIEmbeddingModelType.BABBAGE
            - OpenAIEmbeddingModelType.ADA
            - OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002

        deployment_name (Optional[str]): Optional deployment of model.
            Defaults to None. If this value is not None, mode and model are
            ignored when choosing the engine (they are still validated at
            construction time). Only available for using AzureOpenAI.
    """

    def __init__(
        self,
        mode: str = OpenAIEmbeddingMode.TEXT_SEARCH_MODE,
        model: str = OpenAIEmbeddingModelType.TEXT_EMBED_ADA_002,
        deployment_name: Optional[str] = None,
    ) -> None:
        """Init params.

        Raises:
            ValueError: If ``mode`` or ``model`` is not a valid enum value.
        """
        super().__init__()
        # Coerce through the enums so invalid strings fail fast here rather
        # than at the first embedding call.
        self.mode = OpenAIEmbeddingMode(mode)
        self.model = OpenAIEmbeddingModelType(model)
        self.deployment_name = deployment_name

    def _get_engine(self, mode_model_dict: dict) -> str:
        """Resolve the engine name to send to the OpenAI API.

        An explicit deployment name always wins; otherwise the engine is
        looked up in ``mode_model_dict`` by ``(self.mode, self.model)``.

        Args:
            mode_model_dict (dict): Either ``_QUERY_MODE_MODEL_DICT`` or
                ``_TEXT_MODE_MODEL_DICT``.

        Raises:
            ValueError: If the (mode, model) pair is not in the table.
        """
        if self.deployment_name is not None:
            return self.deployment_name

        key = (self.mode, self.model)
        if key not in mode_model_dict:
            raise ValueError(f"Invalid mode, model combination: {key}")
        return mode_model_dict[key]

    def _get_query_embedding(self, query: str) -> List[float]:
        """Get query embedding."""
        engine = self._get_engine(_QUERY_MODE_MODEL_DICT)
        return get_embedding(query, engine=engine)

    def _get_text_embedding(self, text: str) -> List[float]:
        """Get text embedding."""
        engine = self._get_engine(_TEXT_MODE_MODEL_DICT)
        return get_embedding(text, engine=engine)

    async def _aget_text_embedding(self, text: str) -> List[float]:
        """Asynchronously get text embedding."""
        engine = self._get_engine(_TEXT_MODE_MODEL_DICT)
        return await aget_embedding(text, engine=engine)

    def _get_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Get text embeddings.

        Embeds all ``texts`` with a single batched ``get_embeddings`` call
        rather than one request per text.
        """
        engine = self._get_engine(_TEXT_MODE_MODEL_DICT)
        return get_embeddings(texts, engine=engine)

    async def _aget_text_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Asynchronously get text embeddings.

        Embeds all ``texts`` with a single batched ``aget_embeddings`` call.
        """
        engine = self._get_engine(_TEXT_MODE_MODEL_DICT)
        return await aget_embeddings(texts, engine=engine)