# Stray paste artifact (apparently Modal dashboard status text: "Spaces: Sleeping, Sleeping") —
# kept as a comment so the module remains valid Python; safe to delete.
import json | |
import logging | |
import time | |
from pathlib import Path | |
from typing import Any, Callable | |
import modal | |
import numpy as np | |
import numpy.typing as npt | |
import torch | |
from sentence_transformers import SentenceTransformer | |
from article_embedding.embed import SentenceTransformerModel, StellaEmbedder | |
log = logging.getLogger(__name__) | |
def load_model() -> SentenceTransformer:
    """Instantiate the Stella 400M v5 embedding model.

    Memory-efficient attention and input unpadding are explicitly disabled
    via ``config_kwargs`` (presumably for compatibility with the target
    runtime — TODO confirm). ``trust_remote_code`` is required because the
    model ships custom modeling code.
    """
    model_config = {"use_memory_efficient_attention": False, "unpad_inputs": False}
    return SentenceTransformer(
        "dunzhang/stella_en_400M_v5",
        trust_remote_code=True,
        config_kwargs=model_config,
    )
# Modal container image: Debian slim + Python 3.12 with sentence-transformers
# installed; load_model() runs at image-build time so the model weights are
# baked into the image instead of downloaded on every container start.
image = modal.Image.debian_slim(python_version="3.12").pip_install(["sentence-transformers"]).run_function(load_model)
app = modal.App("embedding", image=image)
class ModalEmbedder:
    """Hosts the Stella model and encodes documents on the best available device.

    NOTE(review): ``modal_amain`` calls ``embed.remote``, which normally needs a
    Modal decorator (e.g. ``@app.cls``) on this class — confirm whether one was
    lost, as none is visible here.
    """

    def setup(self) -> None:
        """Configure logging, then load the model onto CUDA, MPS, or CPU."""
        logging.basicConfig(level=logging.WARN)
        log.setLevel(logging.DEBUG)
        # Prefer CUDA, then Apple MPS, then plain CPU.
        if torch.cuda.is_available():
            target = "cuda"
        else:
            target = "mps" if torch.backends.mps.is_available() else "cpu"
        self.model = load_model().to(target)
        log.info("Model loaded on %s", target)

    def embed(self, documents: list[str]) -> Any:
        """Encode *documents* into embedding vectors."""
        return self.model.encode(documents)
async def fetch_documents() -> list[str]:
    """Fetch one batch of CouchDB changes and return English article bodies.

    Only the FIRST batch from the changes feed is consumed; its documents are
    ordered by ``_id`` and filtered to English articles. Returns an empty list
    when the feed yields nothing.
    """
    from article_embedding.couchdb import CouchDB

    async for batch in CouchDB().changes(batch_size=256):
        ordered = sorted(batch, key=lambda doc: doc.get("_id"))
        return [
            doc["content"]
            for doc in ordered
            if doc.get("type") == "article" and doc.get("language") == "en"
        ]
    return []
def process(func: Callable[[list[str]], Any], documents: list[str], name: str) -> None:
    """Benchmark an embedding callable and compare its output to a golden file.

    Args:
        func: Callable mapping a list of documents to a sequence of embeddings.
        documents: Documents to embed.
        name: Label used in the golden-file name and in log output.
    """
    func(["Hello, world!"])  # Warmup so cold-start cost is excluded from timing.
    ts0 = time.time()
    embeddings = func(documents)
    benchmark = time.time() - ts0
    # Guard: with no embeddings the per-document latency below divides by zero
    # and embeddings[0] raises IndexError (fetch_documents can return []).
    if len(embeddings) == 0:
        log.warning("%s - no embeddings produced; skipping metrics", name)
        return
    output_path = Path("data/embeddings.json")
    golden_path = Path(f"data/embeddings.{name}-golden.json")
    save_embeddings(embeddings, output_path)
    cosine_distance, rms = compare_embeddings(embeddings, golden_path)
    log.info(
        "%s - MCS: %.2f. RMS: %.2f. Latency: %.2f ms. Size: %d",
        name,
        cosine_distance,
        rms,
        benchmark / len(embeddings) * 1000,
        len(embeddings[0]),
    )
def process2(model: SentenceTransformerModel, documents: list[str], name: str) -> None:
    """Benchmark a SentenceTransformerModel wrapper against a golden file.

    Computes the RMS of pairwise similarities between current and golden
    embeddings; if no golden file exists yet, the current embeddings are
    saved as the new golden and RMS is reported as zero.

    Args:
        model: Embedder wrapper exposing ``embed`` and an inner ``model``.
        documents: Documents to embed.
        name: Label used in the golden-file name and in log output.
    """
    model.embed(["Hello, world!"])  # Warmup so cold-start cost is excluded from timing.
    ts0 = time.time()
    embeddings = model.embed(documents)
    benchmark = time.time() - ts0
    # Guard: with no embeddings the per-document latency below divides by zero
    # and embeddings[0] raises IndexError (fetch_documents can return []).
    if len(embeddings) == 0:
        log.warning("%s - no embeddings produced; skipping metrics", name)
        return
    output_path = Path("data/embeddings.json")
    save_embeddings(embeddings, output_path)
    golden_path = Path(f"data/embeddings.{name}-golden.json")
    if golden_path.exists():
        golden_embeddings: Any = load_embeddings(golden_path)
        similarities = model.model.similarity_pairwise(embeddings, golden_embeddings)
        rms = torch.sqrt(torch.mean(similarities**2))
    else:
        save_embeddings(embeddings, golden_path)
        # BUG FIX: torch.zeros([0]) is an EMPTY tensor; formatting it with
        # "%.2f" in the log call below raises. Use a scalar zero instead.
        rms = torch.tensor(0.0)
    log.info(
        "%s - RMS: %.2f. Latency: %.2f ms. Size: %d",
        name,
        rms,
        benchmark / len(embeddings) * 1000,
        len(embeddings[0]),
    )
def load_embeddings(path: Path) -> list[npt.NDArray[np.float64]]:
    """Read embeddings from *path*: one JSON-encoded vector per line."""
    with path.open() as handle:
        return [np.asarray(json.loads(row)) for row in handle]
def save_embeddings(embeddings: list[npt.NDArray[np.float64]], path: Path) -> None:
    """Write embeddings to *path*, one JSON-encoded vector per line."""
    with path.open("w") as handle:
        handle.writelines(json.dumps(vector.tolist()) + "\n" for vector in embeddings)
def compare_embeddings(embeddings: list[npt.NDArray[np.float64]], golden_path: Path) -> tuple[float, float]:
    """Compare *embeddings* to the golden file at *golden_path*.

    Returns (mean cosine similarity, RMS of elementwise differences).
    If the golden file does not exist yet, the current embeddings become
    the new golden and (0.0, 0.0) is returned.
    """
    if not golden_path.exists():
        save_embeddings(embeddings, golden_path)
        return 0.0, 0.0
    with golden_path.open() as handle:
        golden = np.array([json.loads(row) for row in handle])
    current = np.array(embeddings)
    rms = np.sqrt(np.mean((current - golden) ** 2))
    # Row-wise cosine similarity: dot product over product of norms.
    dots = np.sum(current * golden, axis=1)
    norms = np.linalg.norm(current, axis=1) * np.linalg.norm(golden, axis=1)
    return np.mean(dots / norms), np.mean(rms)
async def modal_amain() -> None:
    """Fetch a document batch and benchmark the remote Modal embedder."""
    logging.basicConfig(level=logging.WARN)
    log.setLevel(logging.DEBUG)
    docs = await fetch_documents()
    # .remote dispatches the call to the Modal-hosted embedder.
    process(ModalEmbedder().embed.remote, docs, "modal")
async def amain() -> None:
    """Fetch a document batch and benchmark the local Stella embedder."""
    embedder = StellaEmbedder()
    # Alternative embedders tried previously:
    # embedder = NvEmbedder()
    # embedder = JasperEmbedder()
    # embedder.model.half()
    docs = await fetch_documents()
    process2(embedder, docs, "stella")
if __name__ == "__main__":
    import asyncio
    # Quiet everything by default but keep this module's logger at DEBUG,
    # then run the local (Stella) benchmark entry point.
    logging.basicConfig(level=logging.WARN)
    log.setLevel(logging.DEBUG)
    asyncio.run(amain())