David Ko committed on
Commit
9d90c9e
·
1 Parent(s): 44253a0

유사 이미지 검색 기능 추가: CLIP 모델과 ChromaDB 벡터 데이터베이스 활용

Browse files
Files changed (3) hide show
  1. api.py +199 -3
  2. frontend/build/similar-images.html +279 -0
  3. requirements.txt +7 -0
api.py CHANGED
@@ -10,6 +10,8 @@ from matplotlib.patches import Rectangle
10
  import time
11
  from flask_cors import CORS
12
  import json
 
 
13
 
14
  app = Flask(__name__, static_folder='static')
15
  CORS(app) # Enable CORS for all routes
@@ -17,6 +19,50 @@ CORS(app) # Enable CORS for all routes
17
  # Model initialization
18
  print("Loading models... This may take a moment.")
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  # YOLOv8 model
21
  yolo_model = None
22
  try:
@@ -409,12 +455,163 @@ def analyze_with_llm():
409
 
410
  vision_results = data['visionResults']
411
  user_query = data['userQuery']
412
-
413
  # Process the query with LLM
414
  result = process_llm_query(vision_results, user_query)
415
-
416
  return jsonify(result)
417
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
418
  @app.route('/api/status', methods=['GET'])
419
  def status():
420
  return jsonify({
@@ -427,7 +624,6 @@ def status():
427
  "device": "GPU" if torch.cuda.is_available() else "CPU"
428
  })
429
 
430
- @app.route('/')
431
  def index():
432
  return send_from_directory('static', 'index.html')
433
 
 
10
  import time
11
  from flask_cors import CORS
12
  import json
13
+ import chromadb
14
+ from chromadb.utils import embedding_functions
15
 
16
  app = Flask(__name__, static_folder='static')
17
  CORS(app) # Enable CORS for all routes
 
19
  # Model initialization
20
  print("Loading models... This may take a moment.")
21
 
22
# Image embedding model (CLIP) used for vector similarity search.
# Both names stay None when loading fails so that request handlers can detect
# that the feature is unavailable.
clip_model = None
clip_processor = None
try:
    # Use the system temp directory as the Hugging Face cache location.
    import tempfile
    temp_dir = tempfile.gettempdir()
    # Assign the variable BEFORE importing transformers. The library is
    # expected to resolve its cache location when it is imported, so setting
    # it afterwards (as this code used to do) may have no effect.
    os.environ["TRANSFORMERS_CACHE"] = temp_dir

    from transformers import CLIPProcessor, CLIPModel

    # Load the CLIP checkpoint used for image embeddings. cache_dir makes the
    # cache location explicit even if transformers was already imported
    # elsewhere before the environment variable was set.
    clip_model = CLIPModel.from_pretrained(
        "openai/clip-vit-base-patch32", cache_dir=temp_dir
    )
    clip_processor = CLIPProcessor.from_pretrained(
        "openai/clip-vit-base-patch32", cache_dir=temp_dir
    )

    print("CLIP model loaded successfully")
except Exception as e:
    print("Error loading CLIP model:", e)
    # A partial load (model loaded, processor failed) must not leave only one
    # of the two names set.
    clip_model = None
    clip_processor = None
42
+
43
# Vector DB setup: a ChromaDB client (in-memory DB) plus the collection that
# holds the image embeddings. Both names are None when setup fails so that
# request handlers can detect the missing backend.
vector_db = None
image_collection = None
try:
    vector_db = chromadb.Client()

    # Embedding function handed to the collection.
    ef = embedding_functions.DefaultEmbeddingFunction()

    # Create the image collection, reusing it if it already exists.
    image_collection = vector_db.create_collection(
        name="image_collection",
        embedding_function=ef,
        get_or_create=True,
    )
except Exception as e:
    print("Error initializing Vector DB:", e)
    # The client may have been created before the failure; clear both names.
    vector_db = None
    image_collection = None
else:
    print("Vector DB initialized successfully")
65
+
66
  # YOLOv8 model
67
  yolo_model = None
68
  try:
 
455
 
456
  vision_results = data['visionResults']
457
  user_query = data['userQuery']
458
+
459
  # Process the query with LLM
460
  result = process_llm_query(vision_results, user_query)
461
+
462
  return jsonify(result)
463
 
464
def generate_image_embedding(image):
    """Generate an L2-normalized CLIP embedding for an image.

    Args:
        image: Image accepted by ``clip_processor``; the callers in this file
            pass an RGB PIL image.

    Returns:
        The normalized embedding as a list of floats, or ``None`` when the
        CLIP model/processor is not loaded, the embedding has a zero or
        non-finite norm, or any error occurs while computing it.
    """
    if clip_model is None or clip_processor is None:
        return None

    try:
        # Preprocess the image into model inputs (PyTorch tensors).
        inputs = clip_processor(images=image, return_tensors="pt")

        # Inference only: no gradients are needed.
        with torch.no_grad():
            image_features = clip_model.get_image_features(**inputs)

        image_embedding = image_features.squeeze().cpu().numpy()
        norm = np.linalg.norm(image_embedding)
        # Dividing by a zero/NaN/inf norm would silently produce an embedding
        # full of NaN/inf values that then pollutes the vector DB.
        if not np.isfinite(norm) or norm == 0:
            print("Error generating image embedding: degenerate embedding norm")
            return None

        return (image_embedding / norm).tolist()
    except Exception as e:
        print(f"Error generating image embedding: {e}")
        return None
485
+
486
def _read_request_image():
    """Read the image attached to the current request as an RGB image.

    The image may arrive as a multipart file field named ``image`` or as a
    base64 string (optionally a ``data:image/...`` URL) in the ``image`` form
    field. File uploads take precedence when both are present.

    Returns:
        An ``(image, error)`` tuple where exactly one element is ``None``:
        ``image`` is the decoded RGB image on success, ``error`` is a
        client-facing message when the image is missing or undecodable.
    """
    if 'image' in request.files:
        source = request.files['image']
    elif 'image' in request.form:
        image_data = request.form['image']
        if image_data.startswith('data:image'):
            # Drop the data URL prefix; partition() cannot raise when the
            # comma is missing (split(',')[1] would raise IndexError).
            image_data = image_data.partition(',')[2]
        try:
            source = BytesIO(base64.b64decode(image_data))
        except ValueError:  # binascii.Error is a ValueError subclass
            return None, "Invalid image data"
    else:
        return None, "No image provided"

    try:
        return Image.open(source).convert('RGB'), None
    except Exception:
        # Any decoding failure here is caused by the client's payload.
        return None, "Invalid image data"


@app.route('/api/similar-images', methods=['POST'])
def find_similar_images():
    """Similar image search API.

    Embeds the uploaded image with CLIP and returns up to five nearest
    entries of the image collection. The query image itself is not stored;
    use ``/api/add-to-collection`` for that.

    Returns:
        JSON ``{"query_image_id", "similar_images": [{"id", "distance",
        "metadata"}, ...]}`` on success. On failure, JSON ``{"error": ...}``
        with status 503 (model or vector DB unavailable), 400 (missing or
        invalid image) or 500 (embedding failure / unexpected error).
    """
    if clip_model is None or clip_processor is None or image_collection is None:
        return jsonify({"error": "Image embedding model or vector DB not available"}), 503

    try:
        image, error = _read_request_image()
        if error is not None:
            return jsonify({"error": error}), 400

        # Temporary identifier for the query image (it is not persisted).
        image_id = str(uuid.uuid4())

        embedding = generate_image_embedding(image)
        if embedding is None:
            return jsonify({"error": "Failed to generate image embedding"}), 500

        similar_images = []
        stored = image_collection.count()
        if stored > 0:
            # Never ask for more neighbours than the collection holds, and
            # skip the query entirely when the collection is empty.
            results = image_collection.query(
                query_embeddings=[embedding],
                n_results=min(5, stored),  # top 5 results at most
            )
            ids = results['ids'][0]
            # These keys can be present but None when a field is not included.
            distances = (results.get('distances') or [[]])[0]
            metadatas = (results.get('metadatas') or [[]])[0]
            for i, img_id in enumerate(ids):
                metadata = metadatas[i] if i < len(metadatas) else None
                similar_images.append({
                    "id": img_id,
                    "distance": float(distances[i]) if i < len(distances) else 0.0,
                    # Entries stored without metadata come back as None.
                    "metadata": metadata or {},
                })

        return jsonify({
            "query_image_id": image_id,
            "similar_images": similar_images
        })

    except Exception as e:
        print(f"Error in similar-images API: {e}")
        return jsonify({"error": str(e)}), 500


@app.route('/api/add-to-collection', methods=['POST'])
def add_to_collection():
    """API that adds an image to the vector DB.

    Form fields: ``image`` (file upload or base64 string, required),
    ``metadata`` (JSON object, optional) and ``id`` (optional; a UUID is
    generated when it is missing or empty).

    Returns:
        JSON ``{"success": True, "image_id", "message"}`` on success. On
        failure, JSON ``{"error": ...}`` with status 503 (model or vector DB
        unavailable), 400 (missing/invalid image or metadata) or 500
        (embedding failure / unexpected error).
    """
    if clip_model is None or clip_processor is None or image_collection is None:
        return jsonify({"error": "Image embedding model or vector DB not available"}), 503

    try:
        image, error = _read_request_image()
        if error is not None:
            return jsonify({"error": error}), 400

        metadata = {}
        if 'metadata' in request.form:
            try:
                metadata = json.loads(request.form['metadata'])
            except ValueError:  # json.JSONDecodeError is a ValueError subclass
                return jsonify({"error": "Invalid metadata JSON"}), 400
            if not isinstance(metadata, dict):
                return jsonify({"error": "Metadata must be a JSON object"}), 400

        # Image ID (auto-generated when not provided or empty).
        image_id = request.form.get('id') or str(uuid.uuid4())

        embedding = generate_image_embedding(image)
        if embedding is None:
            return jsonify({"error": "Failed to generate image embedding"}), 500

        image_collection.add(
            ids=[image_id],
            embeddings=[embedding],
            # NOTE(review): some ChromaDB versions reject an empty metadata
            # dict, so pass None when there is nothing to store.
            metadatas=[metadata] if metadata else None,
        )

        return jsonify({
            "success": True,
            "image_id": image_id,
            "message": "Image added to collection"
        })

    except Exception as e:
        print(f"Error in add-to-collection API: {e}")
        return jsonify({"error": str(e)}), 500
600
+
601
@app.route('/', defaults={'path': ''}, methods=['GET'])
@app.route('/<path:path>', methods=['GET'])
def serve_react(path):
    """Serve a frontend asset from the static folder, else ``index.html``.

    Args:
        path: Request path relative to the site root ('' for the root URL).

    Returns:
        The matching static file when ``path`` names an existing regular
        file; otherwise ``index.html`` so that client-side routes resolve.
    """
    # isfile() rather than exists(): a directory also "exists", but
    # send_from_directory only serves regular files, so a directory path
    # should take the index.html fallback instead.
    if path and os.path.isfile(os.path.join(app.static_folder, path)):
        return send_from_directory(app.static_folder, path)
    return send_from_directory(app.static_folder, 'index.html')
609
+
610
@app.route('/similar-images', methods=['GET'])
def similar_images_page():
    """Return the similar-image search page from the static folder."""
    page_name = 'similar-images.html'
    return send_from_directory(app.static_folder, page_name)
614
+
615
  @app.route('/api/status', methods=['GET'])
616
  def status():
617
  return jsonify({
 
624
  "device": "GPU" if torch.cuda.is_available() else "CPU"
625
  })
626
 
 
627
  def index():
628
  return send_from_directory('static', 'index.html')
629
 
frontend/build/similar-images.html ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>유사 이미지 검색</title>
7
+ <link rel="stylesheet" href="https://cdn.jsdelivr.net/npm/[email protected]/dist/css/bootstrap.min.css">
8
+ <style>
9
+ .image-container {
10
+ display: flex;
11
+ flex-wrap: wrap;
12
+ gap: 15px;
13
+ margin-top: 20px;
14
+ }
15
+ .image-card {
16
+ border: 1px solid #ddd;
17
+ border-radius: 8px;
18
+ padding: 10px;
19
+ width: 220px;
20
+ }
21
+ .image-preview {
22
+ width: 200px;
23
+ height: 200px;
24
+ object-fit: cover;
25
+ border-radius: 4px;
26
+ margin-bottom: 10px;
27
+ }
28
+ .spinner-border {
29
+ display: none;
30
+ }
31
+ .result-container {
32
+ margin-top: 30px;
33
+ }
34
+ .similar-image {
35
+ width: 150px;
36
+ height: 150px;
37
+ object-fit: cover;
38
+ border-radius: 4px;
39
+ }
40
+ .similar-item {
41
+ margin-bottom: 15px;
42
+ }
43
+ </style>
44
+ </head>
45
+ <body>
46
+ <div class="container mt-5">
47
+ <h1 class="mb-4">유사 이미지 검색</h1>
48
+
49
+ <div class="row">
50
+ <div class="col-md-6">
51
+ <div class="card">
52
+ <div class="card-header">
53
+ <h5>이미지 업로드</h5>
54
+ </div>
55
+ <div class="card-body">
56
+ <form id="uploadForm">
57
+ <div class="mb-3">
58
+ <label for="imageInput" class="form-label">이미지 선택</label>
59
+ <input type="file" class="form-control" id="imageInput" accept="image/*">
60
+ </div>
61
+ <div class="mb-3">
62
+ <div class="form-check">
63
+ <input class="form-check-input" type="checkbox" id="addToCollection">
64
+ <label class="form-check-label" for="addToCollection">
65
+ 컬렉션에 이미지 추가
66
+ </label>
67
+ </div>
68
+ </div>
69
+ <button type="submit" class="btn btn-primary">
70
+ <span class="spinner-border spinner-border-sm" id="searchSpinner" role="status" aria-hidden="true"></span>
71
+ 유사 이미지 검색
72
+ </button>
73
+ </form>
74
+
75
+ <div class="mt-3">
76
+ <div id="previewContainer" style="display: none;">
77
+ <h6>업로드된 이미지:</h6>
78
+ <img id="imagePreview" class="image-preview" src="" alt="Preview">
79
+ </div>
80
+ </div>
81
+ </div>
82
+ </div>
83
+
84
+ <div class="card mt-4">
85
+ <div class="card-header">
86
+ <h5>샘플 이미지 추가</h5>
87
+ </div>
88
+ <div class="card-body">
89
+ <p>벡터 DB에 샘플 이미지를 추가합니다.</p>
90
+ <button id="addSamplesBtn" class="btn btn-secondary">
91
+ <span class="spinner-border spinner-border-sm" id="sampleSpinner" role="status" aria-hidden="true"></span>
92
+ 샘플 이미지 추가
93
+ </button>
94
+ </div>
95
+ </div>
96
+ </div>
97
+
98
+ <div class="col-md-6">
99
+ <div class="card">
100
+ <div class="card-header">
101
+ <h5>검색 결과</h5>
102
+ </div>
103
+ <div class="card-body">
104
+ <div id="resultsContainer">
105
+ <p id="noResults">검색 결과가 여기에 표시됩니다.</p>
106
+ <div id="similarImagesContainer" class="row"></div>
107
+ </div>
108
+ </div>
109
+ </div>
110
+ </div>
111
+ </div>
112
+ </div>
113
+
114
+ <script>
115
// Image preview: show the selected file inside the preview container.
document.getElementById('imageInput').addEventListener('change', (e) => {
    const file = e.target.files[0];
    if (!file) {
        return;
    }

    const reader = new FileReader();
    reader.onload = (event) => {
        // The reader result is a data URL that can be used directly as src.
        document.getElementById('imagePreview').src = event.target.result;
        document.getElementById('previewContainer').style.display = 'block';
    };
    reader.readAsDataURL(file);
});
127
+
128
// Form submission: optionally store the image, then search for similar ones.
document.getElementById('uploadForm').addEventListener('submit', async (e) => {
    e.preventDefault();

    const selectedFile = document.getElementById('imageInput').files[0];
    const shouldStore = document.getElementById('addToCollection').checked;

    if (!selectedFile) {
        alert('이미지를 선택해주세요.');
        return;
    }

    // Show the loading indicator while the requests are in flight.
    const spinner = document.getElementById('searchSpinner');
    spinner.style.display = 'inline-block';

    const formData = new FormData();
    formData.append('image', selectedFile);

    // Both endpoints receive the same multipart payload.
    const postImage = (url) => fetch(url, { method: 'POST', body: formData });

    try {
        // Only when the "add to collection" option is checked.
        if (shouldStore) {
            const addResponse = await postImage('/api/add-to-collection');
            const addResult = await addResponse.json();
            console.log('Add to collection result:', addResult);
        }

        // Similar image search.
        const searchResponse = await postImage('/api/similar-images');
        const searchResult = await searchResponse.json();
        console.log('Search result:', searchResult);

        displayResults(searchResult);
    } catch (error) {
        console.error('Error:', error);
        alert('오류가 발생했습니다: ' + error.message);
    } finally {
        // Hide the loading indicator whether or not the requests succeeded.
        spinner.style.display = 'none';
    }
});
176
+
177
// Render the search response into the results panel.
// `results` is the parsed JSON of /api/similar-images: either
// { error } or { similar_images: [{ id, distance, metadata }, ...] }.
function displayResults(results) {
    const container = document.getElementById('similarImagesContainer');
    const noResults = document.getElementById('noResults');

    container.innerHTML = '';

    if (results.error) {
        noResults.textContent = '오류: ' + results.error;
        noResults.style.display = 'block';
        return;
    }

    if (!results.similar_images || results.similar_images.length === 0) {
        noResults.textContent = '유사한 이미지를 찾을 수 없습니다. 먼저 이미지를 컬렉션에 추가해보세요.';
        noResults.style.display = 'block';
        return;
    }

    noResults.style.display = 'none';

    results.similar_images.forEach((item, index) => {
        const col = document.createElement('div');
        col.className = 'col-6 similar-item';

        const card = document.createElement('div');
        card.className = 'card h-100';

        // Use the URL stored in the metadata when there is one; otherwise a
        // placeholder (a real implementation may need an API that returns
        // the image for a given ID).
        const hasUrl = item.metadata && item.metadata.url;
        const imageUrl = hasUrl
            ? item.metadata.url
            : 'https://via.placeholder.com/150?text=Image+' + (index + 1);

        // A distance of exactly 0 is a valid (perfect) match, so test the
        // type instead of truthiness.
        const distance = typeof item.distance === 'number'
            ? item.distance.toFixed(4)
            : 'N/A';

        // Build the card with DOM APIs rather than innerHTML so that
        // server-supplied ids/URLs can never be interpreted as markup.
        const image = document.createElement('img');
        image.className = 'similar-image card-img-top';
        image.alt = 'Similar Image ' + (index + 1);
        image.src = imageUrl;

        const body = document.createElement('div');
        body.className = 'card-body';

        const title = document.createElement('h6');
        title.className = 'card-title';
        title.textContent = '유사도: ' + distance;

        const idLine = document.createElement('p');
        idLine.className = 'card-text';
        // String() guards against ids that are not strings.
        idLine.textContent = 'ID: ' + String(item.id).substring(0, 8) + '...';

        body.appendChild(title);
        body.appendChild(idLine);
        card.appendChild(image);
        card.appendChild(body);
        col.appendChild(card);
        container.appendChild(col);
    });
}
228
+
229
// Add sample images: download each sample and store it in the vector DB.
document.getElementById('addSamplesBtn').addEventListener('click', async function() {
    const spinner = document.getElementById('sampleSpinner');
    spinner.style.display = 'inline-block';

    try {
        // Sample image URLs (replace with suitable images in a real deployment).
        const sampleImages = [
            { url: 'https://source.unsplash.com/random/300x300?cat', label: 'cat' },
            { url: 'https://source.unsplash.com/random/300x300?dog', label: 'dog' },
            { url: 'https://source.unsplash.com/random/300x300?bird', label: 'bird' },
            { url: 'https://source.unsplash.com/random/300x300?flower', label: 'flower' },
            { url: 'https://source.unsplash.com/random/300x300?car', label: 'car' }
        ];

        let addedCount = 0;

        for (const sample of sampleImages) {
            // Download the image; fetch() does not reject on HTTP errors, so
            // check the status before uploading an error page as an image.
            const response = await fetch(sample.url);
            if (!response.ok) {
                throw new Error(`Failed to download sample ${sample.label} (HTTP ${response.status})`);
            }
            const blob = await response.blob();

            const formData = new FormData();
            formData.append('image', blob, 'sample.jpg');
            formData.append('metadata', JSON.stringify({
                label: sample.label,
                url: sample.url
            }));

            const addResponse = await fetch('/api/add-to-collection', {
                method: 'POST',
                body: formData
            });

            const result = await addResponse.json();
            console.log(`Added sample ${sample.label}:`, result);

            // The API reports failures in the JSON body ({ error }), so a
            // resolved fetch alone does not mean the image was stored.
            if (!addResponse.ok || result.error) {
                throw new Error(result.error || `HTTP ${addResponse.status}`);
            }
            addedCount += 1;
        }

        // Report the number actually added instead of a hard-coded count.
        alert(`${addedCount}개의 샘플 이미지가 컬렉션에 추가되었습니다.`);
    } catch (error) {
        console.error('Error adding samples:', error);
        alert('샘플 이미지 추가 중 오류가 발생했습니다: ' + error.message);
    } finally {
        spinner.style.display = 'none';
    }
});
275
+ </script>
276
+
277
+ <script src="https://cdn.jsdelivr.net/npm/[email protected]/dist/js/bootstrap.bundle.min.js"></script>
278
+ </body>
279
+ </html>
requirements.txt CHANGED
@@ -24,3 +24,10 @@ accelerator>=0.20.0
24
  bitsandbytes>=0.41.0
25
  sentencepiece>=0.1.99
26
  protobuf>=4.23.0
 
 
 
 
 
 
 
 
24
  bitsandbytes>=0.41.0
25
  sentencepiece>=0.1.99
26
  protobuf>=4.23.0
27
+
28
+ # Vector DB and image similarity search
29
+ chroma-hnswlib>=0.7.3
30
+ chromadb>=0.4.18
31
+ scipy>=1.11.0
32
+ clip-hnswlib>=0.3.0
33
+ open-clip-torch>=2.20.0