File size: 2,537 Bytes
5bfdfae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import tcvectordb
from tcvectordb.model.database import Database
from tcvectordb.model.collection import Collection
from tcvectordb.model.index import Index, VectorIndex, FilterIndex, HNSWParams
from tcvectordb.model.enum import FieldType, IndexType, MetricType
# Keys that ``VectorDB.__init__`` looks up through ``config.get()``.
VDB_ADDRESS = "vector_db.address"  # value is passed as the client ``url``
VDB_KEY = "vector_db.key"  # value is passed as the client ``key``
AI_DB_NAME = "vector_db.ai_db"  # value is used as the database name
AI_COLLECTION_NAME = "vector_db.ai_graph_emb_collection"  # value is used as the collection name


class VectorDB:
    """Wrapper around a ``tcvectordb`` RPC client for one database/collection.

    Connection settings and the database / collection names are read from a
    config object at construction time.  The ``init_*`` methods (re)create
    the database and the graph-embedding collection; ``get_collection``
    returns a handle to that collection.
    """

    def __init__(self, config):
        """Read settings from *config*, build the client and probe the server.

        Args:
            config: Object exposing ``get(key)``.  It is queried with the
                module-level keys ``VDB_ADDRESS``, ``VDB_KEY``,
                ``AI_DB_NAME`` and ``AI_COLLECTION_NAME``.

        Raises:
            Whatever the client raises while being created or while listing
            databases; nothing is caught here.
        """
        self.address = config.get(VDB_ADDRESS)
        self.key = config.get(VDB_KEY)
        self.db_name = config.get(AI_DB_NAME)
        self.ai_graph_emb_collection = config.get(AI_COLLECTION_NAME)

        print(f"Try to connect vector db {self.address}")
        self.client = self.create_client()
        # Issue one request right away so a failing call surfaces in the
        # constructor instead of on first real use.
        self._test_simple()

    def create_client(self):
        """Return a new ``RPCVectorDBClient`` for the configured address/key.

        The user name is fixed to ``'root'`` and the timeout to ``30``.
        """
        return tcvectordb.RPCVectorDBClient(
            url=self.address,
            username='root',
            key=self.key,
            timeout=30,
        )

    def _test_simple(self):
        """Smoke-test the client by listing databases (result is discarded)."""
        self.client.list_databases()

    @staticmethod
    def _create_or_recreate(create, drop):
        """Run ``create()``; on ``VectorDBException`` run ``drop()`` and retry.

        Args:
            create: Zero-argument callable that creates the resource.
            drop: Zero-argument callable that removes the resource.

        Raises:
            tcvectordb.exceptions.VectorDBException: If ``drop()`` or the
                second ``create()`` fails.
        """
        try:
            create()
        except tcvectordb.exceptions.VectorDBException:
            # NOTE(review): *any* VectorDBException from create() leads to a
            # drop, not only an "already exists" error -- confirm this is the
            # intended reset semantics before using against valuable data.
            drop()
            create()

    def init_database(self):
        """Create the configured database, dropping and recreating on error.

        WARNING: destructive.  If creation raises ``VectorDBException`` the
        database named ``self.db_name`` is dropped and created again.
        """
        def create():
            self.client.create_database(self.db_name)

        def drop():
            self.client.drop_database(self.db_name)

        self._create_or_recreate(create, drop)

    def init_graph_collection(self, dimension=512, shard=1, replicas=2):
        """Create the graph-embedding collection, recreating it on error.

        The collection is indexed by a string primary key ``id``, a string
        filter field ``local_graph_path`` and an HNSW vector field ``vector``
        using cosine distance (``m=16``, ``efconstruction=200``).

        WARNING: destructive.  If creation raises ``VectorDBException`` the
        collection named ``self.ai_graph_emb_collection`` is dropped and
        created again.

        Args:
            dimension: Dimension of the ``vector`` index.  Defaults to 512.
            shard: Shard count passed to ``create_collection``.  Defaults to 1.
            replicas: Replica count passed to ``create_collection``.
                Defaults to 2.
        """
        index = Index(
            FilterIndex(name='id', field_type=FieldType.String,
                        index_type=IndexType.PRIMARY_KEY),
            FilterIndex(name='local_graph_path', field_type=FieldType.String,
                        index_type=IndexType.FILTER),
            VectorIndex(name='vector', dimension=dimension,
                        index_type=IndexType.HNSW,
                        metric_type=MetricType.COSINE,
                        params=HNSWParams(m=16, efconstruction=200)),
        )

        database: Database = self.client.database(self.db_name)

        def create():
            database.create_collection(
                name=self.ai_graph_emb_collection,
                shard=shard,
                replicas=replicas,
                index=index,
                description='this is a collection of graph embedding',
            )

        def drop():
            database.drop_collection(self.ai_graph_emb_collection)

        self._create_or_recreate(create, drop)

    def get_collection(self) -> Collection:
        """Return the configured collection from the configured database."""
        database: Database = self.client.database(self.db_name)
        return database.collection(self.ai_graph_emb_collection)