File size: 4,137 Bytes
d6a2797
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a1e6264
d6a2797
 
 
 
 
 
480c871
a1e6264
480c871
 
d6a2797
 
 
 
 
a1e6264
 
 
 
36c2b65
a1e6264
36c2b65
a1e6264
 
 
 
d6a2797
 
 
 
 
 
 
 
 
 
 
 
 
 
a1e6264
 
 
36c2b65
 
 
 
 
 
 
93f3952
d6a2797
93f3952
 
 
 
d6a2797
93f3952
d6a2797
 
 
480c871
36c2b65
d6a2797
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os

import psycopg2
from sentence_transformers import SentenceTransformer

class ProductDatabase:
    """Store and search text embeddings for rows of the ``diamondprice`` table.

    Embeddings come from the ``all-MiniLM-L6-v2`` SentenceTransformer model and
    are stored in a pgvector column named ``vector_col``. Call ``connect()``
    before any other database method and ``close()`` when done.
    """

    def __init__(self, database_url):
        """Remember the connection URL and load the embedding model.

        Args:
            database_url: Connection string passed to ``psycopg2.connect``.
        """
        self.database_url = database_url
        self.conn = None  # populated by connect()
        self.model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

    def connect(self):
        """Open a database connection and keep it on ``self.conn``."""
        self.conn = psycopg2.connect(self.database_url)

    def close(self):
        """Close the connection if one is open; safe to call repeatedly."""
        if self.conn is not None:
            self.conn.close()
            self.conn = None

    def setup_vector_extension_and_column(self):
        """Install the pgvector extension and add ``vector_col`` if missing.

        The column must live on ``diamondprice`` because that is the table
        ``insert_vector`` and the search methods read and write.
        """
        # `with self.conn` commits on success and rolls back on error, so a
        # failed statement does not leave the connection in an aborted state.
        with self.conn, self.conn.cursor() as cursor:
            cursor.execute("CREATE EXTENSION IF NOT EXISTS vector;")
            # 384 matches the vector length the column is declared with for
            # this model's embeddings.
            cursor.execute(
                "ALTER TABLE diamondprice ADD COLUMN IF NOT EXISTS vector_col vector(384);"
            )

    def get_embedding(self, text):
        """Return the model's embedding for *text* (as returned by ``encode``)."""
        return self.model.encode(text)

    def insert_vector(self, product_id, text):
        """Embed *text* and store it in ``vector_col`` of row ``id = product_id``.

        Args:
            product_id: Value of the ``id`` column of the row to update.
            text: Text to embed.
        """
        # psycopg2 cannot adapt a numpy array, so convert it to a plain list.
        vector = self.get_embedding(text).tolist()
        with self.conn, self.conn.cursor() as cursor:
            cursor.execute(
                "UPDATE diamondprice SET vector_col = %s::vector WHERE id = %s",
                (vector, product_id),
            )

    def search_similar_vectors(self, query_text, top_k=5):
        """Return the *top_k* rows whose stored embedding is closest to *query_text*.

        Rows without an embedding are skipped.

        Returns:
            A list of tuples ``(id, carat, cut, color, clarity, depth, table,
            x, y, z, distance)`` ordered from most to least similar.
        """
        query_vector = self.get_embedding(query_text).tolist()
        with self.conn, self.conn.cursor() as cursor:
            # `<=>` is pgvector's cosine *distance*: smaller means more similar,
            # so the nearest neighbours come first with ascending order.
            cursor.execute("""
                SELECT id, carat, cut, color, clarity, depth, diamondprice.table, x, y, z,
                       vector_col <=> %s::vector AS distance
                FROM diamondprice
                WHERE vector_col IS NOT NULL
                ORDER BY distance ASC
                LIMIT %s;
            """, (query_vector, top_k))
            return cursor.fetchall()

    def search_similar_all(self, query_text=None, top_k=5):
        """Return every row of ``diamondprice`` ordered by ``id``.

        ``query_text`` and ``top_k`` are accepted only for backward
        compatibility and are ignored: no similarity ranking or limit is
        applied.

        Returns:
            A list of tuples ``(id, carat, cut, color, clarity, depth, table,
            x, y, z)``.
        """
        with self.conn, self.conn.cursor() as cursor:
            # No placeholders in the SQL, so no parameters may be passed
            # (psycopg2 rejects unused arguments).
            cursor.execute("""
                SELECT id, carat, cut, color, clarity, depth, diamondprice.table, x, y, z
                FROM diamondprice
                ORDER BY id ASC
            """)
            return cursor.fetchall()

def main(*, backfill=False):
    """Set up pgvector, optionally embed every row, then run a sample search.

    Args:
        backfill: When true, compute and store an embedding for every row of
            ``diamondprice`` before searching. This replaces the old
            hard-coded ``DEBUG`` flag.
    """
    # Prefer the DATABASE_URL environment variable so credentials do not have
    # to live in source control; the literal is kept only as a legacy fallback.
    database_url = os.environ.get(
        "DATABASE_URL",
        "postgresql://miyataken999:[email protected]/neondb?sslmode=require",
    )

    db = ProductDatabase(database_url)
    db.connect()

    try:
        db.setup_vector_extension_and_column()
        print("Vector extension installed and column added successfully.")

        if backfill:
            rows = db.search_similar_all("1")
            print("Search results:")
            for row in rows:
                print(row)
                product_id = row[0]
                # Join carat..z with "-", the same shape as the sample query text.
                sample_text = "-".join(str(value) for value in row[1:10])
                db.insert_vector(product_id, sample_text)

        query_text = "2.03-Very Good-J-SI2"
        results = db.search_similar_vectors(query_text)
        print("Search results:")
        for result in results:
            print(result)

    finally:
        db.close()


if __name__ == "__main__":
    main()