# Uploaded by Shreyas094 ("Upload 528 files", commit 372531f, verified).
import os
from typing import Any
# Name of the OpenAI embedding model; the OPENAI_EMBEDDING_MODEL environment
# variable overrides the built-in fallback when it is set.
OPENAI_EMBEDDING_MODEL = os.getenv("OPENAI_EMBEDDING_MODEL", "text-embedding-3-small")
_SUPPORTED_PROVIDERS = {
"openai",
"azure_openai",
"cohere",
"google_vertexai",
"google_genai",
"fireworks",
"ollama",
"together",
"mistralai",
"huggingface",
"nomic",
"voyageai",
"dashscope",
"custom",
"bedrock",
}
class Memory:
def __init__(self, embedding_provider: str, model: str, **embdding_kwargs: Any):
_embeddings = None
match embedding_provider:
case "custom":
from langchain_openai import OpenAIEmbeddings
_embeddings = OpenAIEmbeddings(
model=model,
openai_api_key=os.getenv("OPENAI_API_KEY", "custom"),
openai_api_base=os.getenv(
"OPENAI_BASE_URL", "http://localhost:1234/v1"
), # default for lmstudio
check_embedding_ctx_length=False,
**embdding_kwargs,
) # quick fix for lmstudio
case "openai":
from langchain_openai import OpenAIEmbeddings
_embeddings = OpenAIEmbeddings(model=model, **embdding_kwargs)
case "azure_openai":
from langchain_openai import AzureOpenAIEmbeddings
_embeddings = AzureOpenAIEmbeddings(
model=model,
azure_endpoint=os.environ["AZURE_OPENAI_ENDPOINT"],
openai_api_key=os.environ["AZURE_OPENAI_API_KEY"],
openai_api_version=os.environ["AZURE_OPENAI_API_VERSION"],
**embdding_kwargs,
)
case "cohere":
from langchain_cohere import CohereEmbeddings
_embeddings = CohereEmbeddings(model=model, **embdding_kwargs)
case "google_vertexai":
from langchain_google_vertexai import VertexAIEmbeddings
_embeddings = VertexAIEmbeddings(model=model, **embdding_kwargs)
case "google_genai":
from langchain_google_genai import GoogleGenerativeAIEmbeddings
_embeddings = GoogleGenerativeAIEmbeddings(
model=model, **embdding_kwargs
)
case "fireworks":
from langchain_fireworks import FireworksEmbeddings
_embeddings = FireworksEmbeddings(model=model, **embdding_kwargs)
case "ollama":
from langchain_ollama import OllamaEmbeddings
_embeddings = OllamaEmbeddings(
model=model,
base_url=os.environ["OLLAMA_BASE_URL"],
**embdding_kwargs,
)
case "together":
from langchain_together import TogetherEmbeddings
_embeddings = TogetherEmbeddings(model=model, **embdding_kwargs)
case "mistralai":
from langchain_mistralai import MistralAIEmbeddings
_embeddings = MistralAIEmbeddings(model=model, **embdding_kwargs)
case "huggingface":
from langchain_huggingface import HuggingFaceEmbeddings
_embeddings = HuggingFaceEmbeddings(model_name=model, **embdding_kwargs)
case "nomic":
from langchain_nomic import NomicEmbeddings
_embeddings = NomicEmbeddings(model=model, **embdding_kwargs)
case "voyageai":
from langchain_voyageai import VoyageAIEmbeddings
_embeddings = VoyageAIEmbeddings(
voyage_api_key=os.environ["VOYAGE_API_KEY"],
model=model,
**embdding_kwargs,
)
case "dashscope":
from langchain_community.embeddings import DashScopeEmbeddings
_embeddings = DashScopeEmbeddings(model=model, **embdding_kwargs)
case "bedrock":
from langchain_aws.embeddings import BedrockEmbeddings
_embeddings = BedrockEmbeddings(model_id=model, **embdding_kwargs)
case _:
raise Exception("Embedding not found.")
self._embeddings = _embeddings
def get_embeddings(self):
return self._embeddings