# senatus-qdrant / senatus_client.py
import json
import uuid
import numpy as np
import os
from huggingface_hub import login
from fastembed import SparseTextEmbedding, LateInteractionTextEmbedding
from qdrant_client import QdrantClient, models
from sentence_transformers import SentenceTransformer
from tqdm import tqdm
from config import HUGGING_FACE_API_KEY, DENSE_MODEL, SPARSE_MODEL, LATE_INTERACTION_MODEL, QDRANT_URL, QDRANT_API_KEY, COLLECTION_NAME
login(HUGGING_FACE_API_KEY)
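# Embedding models: dense vectors via sentence-transformers, sparse vectors via fastembed.
# The late-interaction (ColBERT-style) model is wired up but currently disabled.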
folder_path = 'data'
dense_model = SentenceTransformer(DENSE_MODEL)
sparse_model = SparseTextEmbedding(SPARSE_MODEL)
# late_interaction_embedding_model = LateInteractionTextEmbedding(LATE_INTERACTION_MODEL)
data = []
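# Collect records from every JSON file in the data folder; each file is expected to hold a list of records.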
for filename in os.listdir(folder_path):
    if filename.endswith('.json'):
        file_path = os.path.join(folder_path, filename)
        with open(file_path, encoding='utf-8') as f:
            data.extend(json.load(f))
client = QdrantClient(QDRANT_URL, api_key=QDRANT_API_KEY)
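# Split the records into ~1000 chunks so embedding and uploading happen in manageable batches.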
data_array = np.array(data)
split_data = np.array_split(data_array, 1000)
collection_name = COLLECTION_NAME
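# For each chunk: embed the "tekst" field (dense + sparse) and keep the full record as the point payload.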
for local_data in split_data:
    payload = []
    documents = []
    for obj in local_data:
        documents.append(obj["tekst"])
        payload.append(obj)
    sparse_embeddings = list(
        tqdm(
            sparse_model.passage_embed(documents),
            total=len(documents),
            desc="🔨 Encoding Sparse Embeddings",
        )
    )
    # late_interaction_embeddings = list(
    #     tqdm(
    #         late_interaction_embedding_model.passage_embed(doc for doc in documents),
    #         total=len(documents),
    #         desc="🔨 Encoding Late Interaction Embeddings",
    #     )
    # )
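    # Dense embeddings are encoded on CUDA; switch device to "cpu" if no GPU is available.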
    dense_embeddings = dense_model.encode(documents, show_progress_bar=True, device="cuda")
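    # Create the collection on the first pass: named dense + sparse vectors, INT8 scalar
    # quantization kept in RAM, HNSW stored on disk, and a raised indexing threshold for bulk loading.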
    existing_collections = client.get_collections().collections
    collection_names = [col.name for col in existing_collections]
    if collection_name not in collection_names:
        client.create_collection(
            collection_name=collection_name,
            vectors_config={
                DENSE_MODEL: models.VectorParams(
                    size=len(dense_embeddings[0]),
                    distance=models.Distance.COSINE,
                    on_disk=True
                ),
                # LATE_INTERACTION_MODEL: models.VectorParams(
                #     size=len(late_interaction_embeddings[0][0]),
                #     distance=models.Distance.COSINE,
                #     multivector_config=models.MultiVectorConfig(
                #         comparator=models.MultiVectorComparator.MAX_SIM,
                #     ),
                #     hnsw_config=models.HnswConfigDiff(
                #         m=0,  # Disable HNSW graph creation
                #     ),
                #     on_disk=True
                # ),
            },
            sparse_vectors_config={
                SPARSE_MODEL: models.SparseVectorParams(
                    modifier=models.Modifier.IDF,
                ),
            },
            quantization_config=models.ScalarQuantization(
                scalar=models.ScalarQuantizationConfig(
                    type=models.ScalarType.INT8,
                    always_ram=True
                )
            ),
            optimizers_config=models.OptimizersConfigDiff(
                indexing_threshold=10000,
            ),
            shard_number=4,
            hnsw_config=models.HnswConfigDiff(on_disk=True),
        )
print("πŸš€ Uploading to qdrant collection: " + collection_name)
    client.upload_points(
        collection_name=collection_name,
        batch_size=32,
        parallel=16,
        points=[
            models.PointStruct(
                id=uuid.uuid4().hex,
                vector={
                    DENSE_MODEL: dense_embedding,
                    SPARSE_MODEL: sparse_embedding.as_object(),
                    # "answerdotai/answerai-colbert-small-v1": late_interaction_embedding
                },
                payload=doc,
            )
            for doc, dense_embedding, sparse_embedding in zip(
                payload, dense_embeddings, sparse_embeddings
            )
        ],
    )
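# After the bulk upload: index the "dbid" payload field for filtering and raise the indexing
# threshold so Qdrant builds the HNSW index over the uploaded points.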
client.create_payload_index(
    collection_name=collection_name,
    field_name="dbid",
    field_schema=models.PayloadSchemaType.INTEGER
)
client.update_collection(
    collection_name=collection_name,
    optimizer_config=models.OptimizersConfigDiff(indexing_threshold=20000),
)
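# Minimal hybrid-query sketch (left commented out, not part of the ingestion run), assuming
# qdrant-client >= 1.10 with the Query API; "your query" is a placeholder string:
#
# dense_query = dense_model.encode("your query").tolist()
# sparse_query = next(sparse_model.query_embed("your query"))
# hits = client.query_points(
#     collection_name=collection_name,
#     prefetch=[
#         models.Prefetch(query=dense_query, using=DENSE_MODEL, limit=20),
#         models.Prefetch(
#             query=models.SparseVector(
#                 indices=sparse_query.indices.tolist(),
#                 values=sparse_query.values.tolist(),
#             ),
#             using=SPARSE_MODEL,
#             limit=20,
#         ),
#     ],
#     query=models.FusionQuery(fusion=models.Fusion.RRF),
#     limit=10,
# )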