# Uploaded by jugarte00 via huggingface_hub (commit 04595e7, verified)
import json
import logging
import time
from pathlib import Path
from typing import Any, Callable
import modal
import numpy as np
import numpy.typing as npt
import torch
from sentence_transformers import SentenceTransformer
from article_embedding.embed import SentenceTransformerModel, StellaEmbedder
# Module-level logger; the entrypoints below raise its level to DEBUG.
log = logging.getLogger(__name__)
def load_model() -> SentenceTransformer:
    """Instantiate the Stella 400M v5 embedding model.

    Memory-efficient attention and input unpadding are explicitly disabled
    via config overrides; the model code is fetched with trust_remote_code.
    """
    model_name = "dunzhang/stella_en_400M_v5"
    config_overrides = {"use_memory_efficient_attention": False, "unpad_inputs": False}
    return SentenceTransformer(model_name, trust_remote_code=True, config_kwargs=config_overrides)
# Modal container image: slim Debian (Python 3.12) with sentence-transformers
# installed; running load_model() at build time bakes the model weights into
# the image so containers start without downloading them.
image = modal.Image.debian_slim(python_version="3.12").pip_install(["sentence-transformers"]).run_function(load_model)
app = modal.App("embedding", image=image)
@app.cls(gpu="A10G")
class ModalEmbedder:
    """Modal class that serves sentence embeddings from an A10G container."""

    @modal.enter()
    def setup(self) -> None:
        """Load the model once per container onto the best available device."""
        logging.basicConfig(level=logging.WARN)
        log.setLevel(logging.DEBUG)
        # Prefer CUDA, then Apple MPS, then plain CPU.
        device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
        self.model = load_model().to(device)
        log.info("Model loaded on %s", device)

    @modal.method()
    def embed(self, documents: list[str]) -> Any:
        """Encode *documents* and return the resulting embedding matrix."""
        return self.model.encode(documents)
async def fetch_documents() -> list[str]:
    """Return English article bodies from the first CouchDB changes batch.

    Only the first batch is consumed; documents are ordered by "_id" so
    repeated runs embed the same documents in the same order. Returns an
    empty list when the changes feed yields nothing.
    """
    from article_embedding.couchdb import CouchDB

    async for batch in CouchDB().changes(batch_size=256):
        ordered = sorted(batch, key=lambda doc: doc.get("_id"))
        english_articles = (doc for doc in ordered if doc.get("type") == "article" and doc.get("language") == "en")
        return [doc["content"] for doc in english_articles]
    return []
def process(func: Callable[[list[str]], Any], documents: list[str], name: str) -> None:
    """Embed *documents* with *func*, persist the vectors, and log benchmark stats.

    The embeddings are written to data/embeddings.json and compared against a
    per-*name* golden file (created on the first run). Logs mean cosine
    similarity, RMS difference, per-document latency, and vector size.

    Args:
        func: Callable that maps a list of texts to their embeddings.
        documents: Texts to embed; nothing is done when this is empty.
        name: Label used for logging and the golden-file name.
    """
    if not documents:
        # Guard: with no documents, len(embeddings) is 0 and the latency
        # division / embeddings[0] below would raise.
        return
    func(["Hello, world!"])  # Warmup: exclude one-time model costs from the benchmark
    ts0 = time.time()
    embeddings = func(documents)
    benchmark = time.time() - ts0
    output_path = Path("data/embeddings.json")
    golden_path = Path(f"data/embeddings.{name}-golden.json")
    save_embeddings(embeddings, output_path)
    cosine_distance, rms = compare_embeddings(embeddings, golden_path)
    log.info(
        "%s - MCS: %.2f. RMS: %.2f. Latency: %.2f ms. Size: %d",
        name,
        cosine_distance,
        rms,
        benchmark / len(embeddings) * 1000,
        len(embeddings[0]),
    )
def process2(model: SentenceTransformerModel, documents: list[str], name: str) -> None:
model.embed(["Hello, world!"]) # Warmup
ts0 = time.time()
embeddings = model.embed(documents)
benchmark = time.time() - ts0
output_path = Path("data/embeddings.json")
save_embeddings(embeddings, output_path)
golden_path = Path(f"data/embeddings.{name}-golden.json")
if golden_path.exists():
golden_embeddings: Any = load_embeddings(golden_path)
similarities = model.model.similarity_pairwise(embeddings, golden_embeddings)
rms = torch.sqrt(torch.mean(similarities**2))
else:
save_embeddings(embeddings, golden_path)
rms = torch.zeros([0])
log.info(
"%s - RMS: %.2f. Latency: %.2f ms. Size: %d",
name,
rms,
benchmark / len(embeddings) * 1000,
len(embeddings[0]),
)
def load_embeddings(path: Path) -> list[npt.NDArray[np.float64]]:
    """Read one JSON-encoded vector per line from *path* into numpy arrays."""
    with path.open() as handle:
        return [np.array(json.loads(row)) for row in handle]
def save_embeddings(embeddings: list[npt.NDArray[np.float64]], path: Path) -> None:
    """Write *embeddings* to *path*, one JSON-encoded vector per line."""
    rows = (json.dumps(vector.tolist()) + "\n" for vector in embeddings)
    with path.open("w") as handle:
        handle.writelines(rows)
def compare_embeddings(embeddings: list[npt.NDArray[np.float64]], golden_path: Path) -> tuple[float, float]:
    """Compare *embeddings* against the golden reference at *golden_path*.

    If the golden file does not exist yet, the current embeddings are saved
    as the new reference and (0.0, 0.0) is returned.

    Returns:
        (mean cosine similarity, RMS of element-wise differences) as plain
        Python floats.
    """
    if not golden_path.exists():
        save_embeddings(embeddings, golden_path)
        return 0.0, 0.0
    with golden_path.open() as f:
        golden_embeddings = [np.array(json.loads(line)) for line in f]
    np_embeddings = np.array(embeddings)
    np_golden_embeddings = np.array(golden_embeddings)
    # rms is already a scalar; the original np.mean(rms) wrapper was redundant.
    rms = np.sqrt(np.mean((np_embeddings - np_golden_embeddings) ** 2))
    # Row-wise dot products, normalized by the product of row norms.
    dot_products = np.einsum("ij,ij->i", np_embeddings, np_golden_embeddings)
    norms = np.linalg.norm(np_embeddings, axis=1) * np.linalg.norm(np_golden_embeddings, axis=1)
    cosine_similarities = dot_products / norms
    # Convert to plain floats so the declared tuple[float, float] holds.
    return float(np.mean(cosine_similarities)), float(rms)
@app.local_entrypoint()
async def modal_amain() -> None:
    """Modal entrypoint: benchmark remote GPU embedding against the golden file."""
    logging.basicConfig(level=logging.WARN)
    log.setLevel(logging.DEBUG)
    worker = ModalEmbedder()
    corpus = await fetch_documents()
    process(worker.embed.remote, corpus, "modal")
async def amain() -> None:
    """Local entrypoint: benchmark the in-process Stella embedder.

    Swap StellaEmbedder for another SentenceTransformerModel implementation
    to benchmark a different model.
    """
    embedder = StellaEmbedder()
    corpus = await fetch_documents()
    process2(embedder, corpus, "stella")
if __name__ == "__main__":
    # Direct (non-Modal) execution: run the local Stella benchmark.
    import asyncio
    logging.basicConfig(level=logging.WARN)
    log.setLevel(logging.DEBUG)
    asyncio.run(amain())