Spaces:
Sleeping
Sleeping
from typing import Dict, List, Optional, Union

import numpy as np

from ..embedding_provider import EmbeddingProvider
class SentenceTransformerEmbedding(EmbeddingProvider):
    """Embedding provider backed by a ``sentence_transformers.SentenceTransformer`` model."""

    def __init__(
        self,
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        device: Optional[str] = None,
        batch_size: int = 32,
        normalize_embeddings: bool = True
    ) -> None:
        """Initialize sentence transformer embedding provider

        Args:
            model_name (str, optional): Name of the sentence transformer model.
                Defaults to "sentence-transformers/all-MiniLM-L6-v2".
            device (str, optional): Device to load the model on (e.g. "cpu", "cuda").
                When None, "cuda" is chosen if ``torch.cuda.is_available()`` reports
                True, otherwise "cpu".
            batch_size (int, optional): Batch size passed to ``encode`` by
                ``embed_documents``. Defaults to 32.
            normalize_embeddings (bool, optional): Forwarded to
                ``SentenceTransformer.encode(normalize_embeddings=...)`` for both
                documents and queries. Defaults to True.
        """
        # Imported inside __init__ (presumably to defer loading the dependency
        # until this provider is actually instantiated).
        from sentence_transformers import SentenceTransformer

        if device is None:
            import torch
            device = "cuda" if torch.cuda.is_available() else "cpu"

        # Always record the device: previously it was only set on the auto-detect
        # path, so an explicitly passed device made get_model_info() raise
        # AttributeError. The same value is handed to the model so the device
        # reported by get_model_info() is the one the model was loaded on.
        self.device = device
        self.model = SentenceTransformer(model_name, device=self.device)
        self.model_name = model_name
        self.batch_size = batch_size
        self.normalize_embeddings = normalize_embeddings

    def embed_documents(self, documents: List[str]) -> np.ndarray:
        """Embed a list of documents

        Args:
            documents (List[str]): List of documents to embed

        Returns:
            np.ndarray: Embeddings as returned by ``SentenceTransformer.encode``,
                one row per input document.
        """
        return self.model.encode(
            documents,
            batch_size=self.batch_size,
            normalize_embeddings=self.normalize_embeddings
        )

    def embed_query(self, query: str) -> np.ndarray:
        """Embed a single query

        Args:
            query (str): Query to embed

        Returns:
            np.ndarray: Embedding vector
        """
        return self.model.encode(
            query,
            normalize_embeddings=self.normalize_embeddings
        )

    def get_model_info(self) -> Dict[str, Union[str, int, bool, None]]:
        """
        Retrieve information about the current embedding model

        Returns:
            Dict: Model name, device, batch size, normalization flag and the
                embedding dimension reported by
                ``get_sentence_embedding_dimension()`` (which the library may
                report as None).
        """
        return {
            "model_name": self.model_name,
            "device": self.device,
            "batch_size": self.batch_size,
            "normalize_embeddings": self.normalize_embeddings,
            "embedding_dim": self.model.get_sentence_embedding_dimension()
        }

    def list_available_models(self) -> List[str]:
        """
        List some popular Sentence Transformer models

        Returns:
            List[str]: Available model names (a fixed, hand-picked list; a new
                list object is returned on every call).
        """
        popular_models = [
            "sentence-transformers/all-MiniLM-L6-v2",  # Small and fast
            "sentence-transformers/all-mpnet-base-v2",  # High performance
            "sentence-transformers/all-distilroberta-v1",  # Lightweight
            "sentence-transformers/multi-qa-MiniLM-L6-cos-v1",  # Question Answering
            "sentence-transformers/multi-qa-mpnet-base-cos-v1",  # Question Answering (higher quality)
            "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"  # Multilingual
        ]
        return popular_models