Spaces:
Sleeping
Sleeping
| import json | |
| import uuid | |
| import numpy as np | |
| import os | |
| from huggingface_hub import login | |
| from fastembed import SparseTextEmbedding,LateInteractionTextEmbedding | |
| from qdrant_client import QdrantClient, models | |
| from sentence_transformers import SentenceTransformer | |
| from tqdm import tqdm | |
| from huggingface_hub import login | |
| from config import HUGGING_FACE_API_KEY, DENSE_MODEL, SPARSE_MODEL, LATE_INTERACTION_MODEL, QDRANT_URL, QDRANT_API_KEY, COLLECTION_NAME | |
# Authenticate against the Hugging Face Hub before any model is instantiated.
login(HUGGING_FACE_API_KEY)

folder_path = 'data'

# Embedding models: a dense sentence-transformer and a fastembed sparse model.
dense_model = SentenceTransformer(DENSE_MODEL)
sparse_model = SparseTextEmbedding(SPARSE_MODEL)
# late_interaction_embedding_model = LateInteractionTextEmbedding(LATE_INTERACTION_MODEL)

# Collect the records of EVERY .json file in the folder. Previously each file
# rebound `data`, so only the last file listed was actually ingested.
data = []
for filename in sorted(os.listdir(folder_path)):  # sorted: reproducible order
    if not filename.endswith('.json'):
        continue
    file_path = os.path.join(folder_path, filename)
    with open(file_path, encoding='utf-8') as f:
        loaded = json.load(f)
    # NOTE(review): each file is presumably a JSON array of record objects;
    # a file holding a single top-level object is kept as one record.
    if isinstance(loaded, list):
        data.extend(loaded)
    else:
        data.append(loaded)

client = QdrantClient(QDRANT_URL, api_key=QDRANT_API_KEY)

# Split the records into 1000 chunks that are embedded and uploaded one at a
# time. np.array_split tolerates uneven division, so chunks can differ in size
# and are empty when there are fewer than 1000 records.
data_array = np.array(data)
split_data = np.array_split(data_array, 1000)

collection_name = COLLECTION_NAME
# Becomes True once the target collection is known to exist, so the collection
# listing is requested from the server only until then rather than per chunk.
collection_ready = False

for local_data in split_data:
    # np.array_split yields empty chunks when there are fewer records than
    # sections; an empty chunk has no embeddings to size the collection from
    # (dense_embeddings[0] would raise IndexError) and nothing to upload.
    if len(local_data) == 0:
        continue

    # Each record is stored whole as the point payload; its "tekst" field is
    # the text that gets embedded.
    payload = list(local_data)
    documents = [obj["tekst"] for obj in payload]

    sparse_embeddings = list(
        tqdm(
            sparse_model.passage_embed(documents),
            total=len(documents),
            desc="π¨ Encoding Sparse Embeddings"
        )
    )
    # late_interaction_embeddings = list(
    #     tqdm(
    #         late_interaction_embedding_model.passage_embed(documents),
    #         total=len(documents),
    #         desc="π¨ Encoding Late Interaction Embeddings"
    #     )
    # )
    dense_embeddings = dense_model.encode(documents, show_progress_bar=True, device="cuda")

    if not collection_ready:
        existing_names = {col.name for col in client.get_collections().collections}
        if collection_name not in existing_names:
            client.create_collection(
                collection_name=collection_name,
                vectors_config={
                    # Dense vector size is taken from the first embedding of
                    # the current (non-empty) chunk.
                    DENSE_MODEL: models.VectorParams(
                        size=len(dense_embeddings[0]),
                        distance=models.Distance.COSINE,
                        on_disk=True
                    ),
                    # LATE_INTERACTION_MODEL: models.VectorParams(
                    #     size=len(late_interaction_embeddings[0][0]),
                    #     distance=models.Distance.COSINE,
                    #     multivector_config=models.MultiVectorConfig(
                    #         comparator=models.MultiVectorComparator.MAX_SIM,
                    #     ),
                    #     hnsw_config=models.HnswConfigDiff(
                    #         m=0,  # Disable HNSW graph creation
                    #     ),
                    #     on_disk=True
                    # ),
                },
                sparse_vectors_config={
                    SPARSE_MODEL: models.SparseVectorParams(
                        modifier=models.Modifier.IDF,
                    ),
                },
                quantization_config=models.ScalarQuantization(
                    scalar=models.ScalarQuantizationConfig(
                        type=models.ScalarType.INT8,
                        always_ram=True
                    )
                ),
                optimizers_config=models.OptimizersConfigDiff(
                    indexing_threshold=10000,
                ),
                shard_number=4,
                hnsw_config=models.HnswConfigDiff(on_disk=True),
            )
        collection_ready = True

    print("π Uploading to qdrant collection: " + collection_name)
    client.upload_points(
        collection_name=collection_name,
        batch_size=32,
        parallel=16,
        points=[
            models.PointStruct(
                # A fresh random id per point: re-running the script adds new
                # points instead of overwriting earlier ones.
                id=uuid.uuid4().hex,
                vector={
                    DENSE_MODEL: dense_embedding,
                    SPARSE_MODEL: sparse_embedding.as_object(),
                    # LATE_INTERACTION_MODEL: late_interaction_embedding
                },
                payload=doc,
            )
            for doc, dense_embedding, sparse_embedding in zip(
                payload, dense_embeddings, sparse_embeddings
            )
        ],
    )
# Build an integer payload index on the "dbid" field of the collection.
dbid_schema = models.PayloadSchemaType.INTEGER
client.create_payload_index(
    collection_name=collection_name,
    field_name="dbid",
    field_schema=dbid_schema,
)

# Set the optimizer's indexing threshold to 20000 (the collection-creation
# call earlier in this script uses 10000).
final_optimizer_settings = models.OptimizersConfigDiff(indexing_threshold=20000)
client.update_collection(
    collection_name=collection_name,
    optimizer_config=final_optimizer_settings,
)