// Assumes that chroma-hnswlib is checked out at the same level as chroma #include "../../../hnswlib/hnswlib/hnswlib.h" template <typename dist_t, typename data_t = float> class Index { public: std::string space_name; int dim; size_t seed; bool normalize; bool index_inited; hnswlib::HierarchicalNSW<dist_t> *appr_alg; hnswlib::SpaceInterface<float> *l2space; Index(const std::string &space_name, const int dim) : space_name(space_name), dim(dim) { if (space_name == "l2") { l2space = new hnswlib::L2Space(dim); normalize = false; } if (space_name == "ip") { l2space = new hnswlib::InnerProductSpace(dim); // For IP, we expect the vectors to be normalized normalize = false; } if (space_name == "cosine") { l2space = new hnswlib::InnerProductSpace(dim); normalize = true; } appr_alg = NULL; index_inited = false; } ~Index() { delete l2space; if (appr_alg) { delete appr_alg; } } void init_index(const size_t max_elements, const size_t M, const size_t ef_construction, const size_t random_seed, const bool allow_replace_deleted, const bool is_persistent_index, const std::string &persistence_location) { if (index_inited) { std::runtime_error("Index already inited"); } appr_alg = new hnswlib::HierarchicalNSW<dist_t>(l2space, max_elements, M, ef_construction, random_seed, allow_replace_deleted, normalize, is_persistent_index, persistence_location); appr_alg->ef_ = 10; // This is a default value for ef_ index_inited = true; } void load_index(const std::string &path_to_index, const bool allow_replace_deleted, const bool is_persistent_index) { if (index_inited) { std::runtime_error("Index already inited"); } appr_alg = new hnswlib::HierarchicalNSW<dist_t>(l2space, path_to_index, false, 0, allow_replace_deleted, normalize, is_persistent_index); index_inited = true; } void persist_dirty() { if (!index_inited) { std::runtime_error("Index not inited"); } appr_alg->persistDirty(); } void add_item(const data_t *data, const hnswlib::labeltype id, const bool replace_deleted = false) { if (!index_inited) { std::runtime_error("Index not inited"); } appr_alg->addPoint(data, id); } void get_item(const hnswlib::labeltype id, data_t *data) { if (!index_inited) { std::runtime_error("Index not inited"); } std::vector<data_t> ret_data = appr_alg->template getDataByLabel<data_t>(id); // This checks if id is deleted for (int i = 0; i < dim; i++) { data[i] = ret_data[i]; } } int mark_deleted(const hnswlib::labeltype id) { if (!index_inited) { std::runtime_error("Index not inited"); } appr_alg->markDelete(id); return 0; } void knn_query(const data_t *query_vector, const size_t k, hnswlib::labeltype *ids, data_t *distance) { if (!index_inited) { std::runtime_error("Index not inited"); } std::priority_queue<std::pair<dist_t, hnswlib::labeltype>> res = appr_alg->searchKnn(query_vector, k); if (res.size() < k) { // TODO: This is ok and we should return < K results, but for maintining compatibility with the old API we throw an error for now std::runtime_error("Not enough results"); } int total_results = std::min(res.size(), k); for (int i = total_results - 1; i >= 0; i--) { std::pair<dist_t, hnswlib::labeltype> res_i = res.top(); ids[i] = res_i.second; distance[i] = res_i.first; res.pop(); } } int get_ef() { if (!index_inited) { std::runtime_error("Index not inited"); } return appr_alg->ef_; } void set_ef(const size_t ef) { if (!index_inited) { std::runtime_error("Index not inited"); } appr_alg->ef_ = ef; } }; extern "C" { Index<float> *create_index(const char *space_name, const int dim) { return new Index<float>(space_name, dim); } void init_index(Index<float> *index, const size_t max_elements, const size_t M, const size_t ef_construction, const size_t random_seed, const bool allow_replace_deleted, const bool is_persistent_index, const char *persistence_location) { index->init_index(max_elements, M, ef_construction, random_seed, allow_replace_deleted, is_persistent_index, persistence_location); } void load_index(Index<float> *index, const char *path_to_index, const bool allow_replace_deleted, const bool is_persistent_index) { index->load_index(path_to_index, allow_replace_deleted, is_persistent_index); } void persist_dirty(Index<float> *index) { index->persist_dirty(); } void add_item(Index<float> *index, const float *data, const hnswlib::labeltype id, const bool replace_deleted) { index->add_item(data, id); } void get_item(Index<float> *index, const hnswlib::labeltype id, float *data) { index->get_item(id, data); } int mark_deleted(Index<float> *index, const hnswlib::labeltype id) { return index->mark_deleted(id); } void knn_query(Index<float> *index, const float *query_vector, const size_t k, hnswlib::labeltype *ids, float *distance) { index->knn_query(query_vector, k, ids, distance); } int get_ef(Index<float> *index) { return index->appr_alg->ef_; } void set_ef(Index<float> *index, const size_t ef) { index->set_ef(ef); } }