Ahmed Tarek committed
Commit 80f913d · 1 Parent(s): e3ca660
.DS_Store ADDED
Binary file (6.15 kB)
 
Dockerfile CHANGED
@@ -1,18 +1,36 @@
-# Use a minimal Python base image
 FROM python:3.10-slim

-# Set working directory
 WORKDIR /app

-# Install dependencies
+# Install system dependencies + filelock for TinyDB
+RUN apt-get update && \
+    apt-get install -y --no-install-recommends g++ && \
+    rm -rf /var/lib/apt/lists/*
+
+# Create directory structure with fail-safes
+RUN mkdir -p /.cache/huggingface/hub && \
+    mkdir -p /tmp/hf_cache && \
+    chmod -R 777 /tmp/hf_cache && \
+    chmod -R 777 /.cache # Full permissions
+
+
+# Install Python dependencies (add filelock)
 COPY requirements.txt .
-RUN pip install --no-cache-dir -r requirements.txt
+RUN pip install --no-cache-dir -r requirements.txt filelock

-# Copy app code
-COPY . .
+# Copy app code (ensure proper permissions)
+COPY --chmod=777 . .

-# Expose the port expected by Hugging Face (usually 7860)
-EXPOSE 7860
+# Environment configuration
+ENV HF_HOME=/tmp/hf_cache \
+    PYTHONUNBUFFERED=1

-# Run the FastAPI app using uvicorn on the exposed port
-CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
+ENV ONNX_MODELS_DIR=/tmp
+ENV HF_HOME=/.cache/huggingface/hub
+
+# Health check (optional but recommended)
+HEALTHCHECK --interval=30s --timeout=3s \
+    CMD curl -f http://localhost:7860/ || exit 1
+
+EXPOSE 7860
+CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
backend/.DS_Store ADDED
Binary file (6.15 kB)
 
backend/app/database/base.py CHANGED
@@ -1,8 +1,26 @@
 from tinydb import TinyDB
+from filelock import FileLock
 import os
+import json
+import logging
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)

 class BaseDB:
     def __init__(self):
-        if not os.path.exists("/tmp/db"):
-            os.makedirs("/tmp/db", exist_ok=True)
-        self.db = TinyDB('/tmp/db/database.json')
+        self.db_path = "/.cache/huggingface/hub/my_app_data/db/database.json"
+        os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
+
+        try:
+            with FileLock(f"{self.db_path}.lock"):
+                # Handle corruption
+                try:
+                    self.db = TinyDB(self.db_path)
+                except json.JSONDecodeError:
+                    logger.warning("DB corrupted - resetting")
+                    os.rename(self.db_path, f"{self.db_path}.bak")
+                    self.db = TinyDB(self.db_path)
+        except Exception as e:
+            logger.error(f"DB init failed: {e}")
+            raise
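
A hedged sketch of how a concrete table class might build on the new BaseDB; the real UserDB/EventDB/TravelDB implementations are not shown in this diff, so the class and field names below are illustrative only.

# Illustrative subclass (not from this commit): shares the TinyDB file opened
# by BaseDB and works against its own table.
from tinydb import Query

class ExampleUserDB(BaseDB):
    def __init__(self):
        super().__init__()                   # opens the shared database.json under a file lock
        self.table = self.db.table("users")  # per-entity table inside the same JSON file

    def upsert_user(self, user_id: str, name: str):
        User = Query()
        self.table.upsert({"id": user_id, "name": name}, User.id == user_id)
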
backend/app/database/chatbot.py DELETED
@@ -1,9 +0,0 @@
-from pydantic import BaseModel
-from typing import List, Optional
-
-class ChatRequest(BaseModel):
-    message: str
-
-class ChatResponse(BaseModel):
-    response: str
-    recommendations: List[str]

backend/app/helper/dependencies.py CHANGED
@@ -1,5 +1,5 @@
 from services.embedding_models.MiniLM_L12_v2_model import ONNXMiniLMModel
-from services.vector_db.similarity_model import VectorDB
+from services.vector_db.optimized_vector_db import VectorDB
 from backend.app.database.users import UserDB
 from backend.app.database.events import EventDB
 from backend.app.database.travels import TravelDB
@@ -12,8 +12,8 @@ import os
 embedding_lock = asyncio.Lock()
 assistant = BilingualTravelAssistant()
 embedding_model = ONNXMiniLMModel()
-events_vector_db = VectorDB(db_path="./vector_db/events")
-travels_vector_db = VectorDB(db_path="./vector_db/travels")
+events_vector_db = VectorDB()
+travels_vector_db = VectorDB()
 event_db = EventDB()
 travel_db = TravelDB()
 user_db = UserDB()
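
With the new defaults, both VectorDB() instances fall back to the same db_path. A hedged sketch of how a route module might consume these module-level singletons; the route path and parameters are illustrative and not part of this commit.

# Illustrative route (not from this commit) using the shared singletons.
from fastapi import APIRouter
from backend.app.helper import dependencies

router = APIRouter()

@router.get("/events/search")
async def search_events(q: str, top_k: int = 5):
    # Serialize access to the ONNX session with the module's asyncio lock.
    async with dependencies.embedding_lock:
        return dependencies.events_vector_db.search_by_query(
            q, dependencies.embedding_model, top_k
        )
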
requirements.txt CHANGED
@@ -18,7 +18,6 @@ Deprecated==1.2.18
 distro==1.9.0
 durationpy==0.9
 exceptiongroup==1.2.2
-faiss-cpu==1.10.0
 fastapi==0.115.9
 filelock==3.18.0
 flatbuffers==25.2.10
@@ -54,7 +53,6 @@ networkx==3.2.1
 numpy==1.26.4
 oauthlib==3.2.2
 onnx==1.16.2
-onnxruntime==1.16.3
 opentelemetry-api==1.32.1
 opentelemetry-exporter-otlp-proto-common==1.32.1
 opentelemetry-exporter-otlp-proto-grpc==1.32.1
@@ -91,7 +89,6 @@ rich==14.0.0
 rpds-py==0.24.0
 rsa==4.9.1
 safetensors==0.5.3
-sentence_transformers==4.1.0
 shellingham==1.5.4
 six==1.17.0
 sniffio==1.3.1
@@ -102,9 +99,7 @@ tenacity==9.1.2
 tinydb==4.8.2
 tokenizers==0.21.1
 tomli==2.2.1
-torch==2.6.0
 tqdm==4.67.1
-transformers==4.51.3
 typer==0.15.2
 typing-inspection==0.4.0
 typing_extensions==4.13.2
@@ -115,3 +110,9 @@ websocket-client==1.8.0
 websockets==15.0.1
 wrapt==1.17.2
 zipp==3.21.0
+sentence-transformers>=2.7.0
+transformers>=4.41.0
+onnxruntime>=1.17.0
+torch>=2.2.0
+huggingface-hub>=0.20.0
+faiss-cpu>=1.7.4

services/.DS_Store ADDED
Binary file (6.15 kB)
 
services/embedding_models/MiniLM_L12_v2_model.py CHANGED
@@ -2,75 +2,148 @@ import os
 import numpy as np
 import onnxruntime as ort
 from pathlib import Path
-from transformers.onnx import export
-from transformers.onnx.features import FeaturesManager
-from transformers.utils import logging
-from transformers import AutoTokenizer, AutoModel
+import logging
+from typing import List, Union

-logging.set_verbosity_error()
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)

 class ONNXMiniLMModel:
     def __init__(self,
-                 model_name="sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
-                 onnx_path="/tmp/onnx_model/minilm.onnx"):  # Different ONNX path
+                 model_name="sentence-transformers/paraphrase-multilingual-minilm-l12-v2",
+                 onnx_path="/tmp/minilm.onnx",
+                 dimension=384):  # Matching VectorDB dimension

         self.model_name = model_name
         self.onnx_path = onnx_path
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        self.dimension = dimension

-        if not os.path.exists(onnx_path):
-            print("[INFO] ONNX model not found. Exporting to ONNX...")
-            self.export_to_onnx()
-
-        print("[INFO] Loading ONNX model...")
-        self.session = ort.InferenceSession(onnx_path)
+        try:
+            # Configure cache and model paths
+            cache_dir = "/tmp/hf_cache"
+            os.makedirs(cache_dir, exist_ok=True)
+            os.environ["HF_HOME"] = cache_dir
+
+            # Initialize model
+            logger.info(f"Loading model {model_name}...")
+            from sentence_transformers import SentenceTransformer
+            self.st_model = SentenceTransformer(
+                model_name,
+                cache_folder=cache_dir,
+                device="cpu"
+            )
+            self.tokenizer = self.st_model.tokenizer
+            self.model = self.st_model._first_module().auto_model
+            self.model.eval()
+
+            # Convert to ONNX if needed
+            if not os.path.exists(onnx_path):
+                self.export_to_onnx()
+
+            # Initialize ONNX runtime
+            logger.info("Creating ONNX inference session...")
+            self.session = ort.InferenceSession(
+                onnx_path,
+                providers=['CPUExecutionProvider']
+            )
+
+            logger.info(f"Model initialized with dimension {dimension}")
+
+        except Exception as e:
+            logger.error(f"Model initialization failed: {str(e)}")
+            raise

     def export_to_onnx(self):
-        model = AutoModel.from_pretrained(self.model_name)
-        save_dir = Path(self.onnx_path).parent
-        save_dir.mkdir(parents=True, exist_ok=True)
-
-        _, model_onnx_config = FeaturesManager.check_supported_model_or_raise(model)
-        onnx_config = model_onnx_config(model.config)
-
-        export(preprocessor=self.tokenizer,
-               model=model,
-               config=onnx_config,
-               opset=14,
-               output=Path(self.onnx_path))
+        """Export the model to ONNX format with proper configuration"""
+        import torch
+        logger.info(f"Exporting model to ONNX at {self.onnx_path}...")
+
+        # Create dummy inputs with correct dimensions and types
+        dummy_input = (
+            torch.randint(0, 100, (1, 128), dtype=torch.long),  # input_ids
+            torch.ones((1, 128), dtype=torch.long),              # attention_mask
+            torch.zeros((1, 128), dtype=torch.long)              # token_type_ids
+        )
+
+        # Export configuration
+        torch.onnx.export(
+            self.model,
+            dummy_input,
+            self.onnx_path,
+            opset_version=14,
+            input_names=["input_ids", "attention_mask", "token_type_ids"],
+            output_names=["output"],
+            dynamic_axes={
+                'input_ids': {0: 'batch', 1: 'sequence'},
+                'attention_mask': {0: 'batch', 1: 'sequence'},
+                'token_type_ids': {0: 'batch', 1: 'sequence'},
+                'output': {0: 'batch'}
+            },
+            do_constant_folding=True
+        )
+        logger.info("ONNX export completed successfully")

     def mean_pooling(self, token_embeddings, attention_mask):
+        """Apply mean pooling to get sentence embeddings"""
         input_mask_expanded = np.expand_dims(attention_mask, -1).astype(np.float32)
-        pooled = np.sum(token_embeddings * input_mask_expanded, axis=1) / np.clip(np.sum(input_mask_expanded, axis=1), 1e-9, None)
-        return pooled
-
-    def encode(self, texts, normalize=True, debug=False):
-        # Tokenize with return_token_type_ids=True
-        tokens = self.tokenizer(
-            texts,
-            padding=True,
-            truncation=True,
-            return_tensors="np",
-            return_token_type_ids=True  # Critical addition
-        )
-
-        if debug:
-            print("[DEBUG] Tokens:", self.tokenizer.convert_ids_to_tokens(tokens["input_ids"][0]))
-
-        # Prepare all required inputs
-        inputs = {
-            "input_ids": tokens["input_ids"].astype(np.int64),
-            "attention_mask": tokens["attention_mask"].astype(np.int64),
-            "token_type_ids": tokens["token_type_ids"].astype(np.int64)  # New required input
-        }
-
-        outputs = self.session.run(None, inputs)
-        embeddings = self.mean_pooling(outputs[0], tokens["attention_mask"])
-
-        if normalize:
-            norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
-            embeddings = embeddings / np.clip(norms, 1e-9, None)
-
-        return embeddings
+        sum_embeddings = np.sum(token_embeddings * input_mask_expanded, axis=1)
+        sum_mask = np.clip(np.sum(input_mask_expanded, axis=1), 1e-9, None)
+        return sum_embeddings / sum_mask
+
+    def encode(self, texts: Union[str, List[str]], normalize: bool = True) -> np.ndarray:
+        """
+        Generate embeddings for input text(s)
+
+        Args:
+            texts: Single text string or list of texts
+            normalize: Whether to normalize embeddings to unit length
+
+        Returns:
+            numpy.ndarray: Embeddings array of shape (num_texts, dimension)
+        """
+        try:
+            # Ensure input is a list
+            if isinstance(texts, str):
+                texts = [texts]
+
+            # Tokenize with proper settings for multilingual model
+            tokens = self.tokenizer(
+                texts,
+                padding=True,
+                truncation=True,
+                max_length=512,
+                return_tensors="np",
+                return_token_type_ids=True
+            )
+
+            # Prepare ONNX inputs
+            inputs = {
+                "input_ids": tokens["input_ids"].astype(np.int64),
+                "attention_mask": tokens["attention_mask"].astype(np.int64),
+                "token_type_ids": tokens["token_type_ids"].astype(np.int64)
+            }
+
+            # Run inference
+            outputs = self.session.run(None, inputs)
+            embeddings = self.mean_pooling(outputs[0], tokens["attention_mask"])
+
+            # Normalize if requested
+            if normalize:
+                norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
+                embeddings = embeddings / np.clip(norms, 1e-9, None)
+
+            # Ensure correct dimensionality
+            if embeddings.shape[1] != self.dimension:
+                logger.warning(f"Embedding dimension mismatch: {embeddings.shape[1]} != {self.dimension}")
+                embeddings = embeddings[:, :self.dimension]  # Truncate if needed
+
+            return embeddings.astype(np.float32)  # Ensure float32 for FAISS
+
+        except Exception as e:
+            logger.error(f"Embedding generation failed: {str(e)}")
+            raise
+
+    def get_dimension(self) -> int:
+        """Return the embedding dimension"""
+        return self.dimension
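
A hedged usage sketch of the encode API above: on first use the constructor downloads the Sentence-Transformers checkpoint into /tmp/hf_cache and exports /tmp/minilm.onnx, after which inference runs through onnxruntime on CPU.

# Usage sketch; with normalize=True (the default) the embeddings are unit-length,
# so a dot product equals cosine similarity.
import numpy as np
from services.embedding_models.MiniLM_L12_v2_model import ONNXMiniLMModel

model = ONNXMiniLMModel()
vecs = model.encode(["Cairo travel tips", "نصائح السفر إلى القاهرة"])
print(vecs.shape)                        # (2, 384) with the default dimension
print(float(np.dot(vecs[0], vecs[1])))   # cosine similarity of the two sentences
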
services/vector_db/optimized_vector_db.py CHANGED
@@ -6,15 +6,13 @@ from typing import List


 class VectorDB:
-    def __init__(self, db_path="./vector_db", dimension=384):
+    def __init__(self, db_path="/.cache/huggingface/hub/my_app_data/vector_db", dimension=384):
         self.db_path = os.path.join(db_path)
         self.index_path = os.path.join(self.db_path, "faiss_index.bin")
-        self.mapping_path = os.path.join(self.db_path, "id_mapping.pkl")
+        self.metadata_path = os.path.join(self.db_path, "metadata.pkl")
         self.dimension = dimension
         self.index = None
-        self.id_to_int = {}  # {"your_str_id": faiss_int_id}
-        self.int_to_id = {}  # {faiss_int_id: "your_str_id"}
-        self.vectors = {}  # {int_id: vector} for fast access
+        self.metadata = {}
         self._initialize_storage()

     def _initialize_storage(self):
@@ -23,16 +21,13 @@ class VectorDB:
             if not os.path.exists(self.db_path):
                 os.makedirs(self.db_path)

-            if os.path.exists(self.index_path) and os.path.exists(self.mapping_path):
+            if os.path.exists(self.index_path) and os.path.exists(self.metadata_path):
                 self.index = faiss.read_index(self.index_path)
-                with open(self.mapping_path, 'rb') as f:
-                    data = pickle.load(f)
-                    self.id_to_int = data.get('id_to_int', {})
-                    self.int_to_id = data.get('int_to_id', {})
-                    self.vectors = data.get('vectors', {})
+                with open(self.metadata_path, 'rb') as f:
+                    self.metadata = pickle.load(f)
             else:
-                base_index = faiss.IndexFlatL2(self.dimension)
-                self.index = faiss.IndexIDMap(base_index)
+                self.index = faiss.IndexFlatL2(self.dimension)
+                self.metadata = {}

             print(f"Storage initialized. Current size: {self.index.ntotal}")
         except Exception as e:
@@ -55,61 +50,78 @@ class VectorDB:
         self.update_embeddings(data, model)

     def update_embeddings(self, data, model):
-        str_ids = [str(item['id']) for item in data]
-        descriptions = [self._format_description(item) for item in data]
-
-        embeddings = model.encode(descriptions).astype("float32")
-
-        existing_mask = np.array([sid in self.id_to_int for sid in str_ids])
-        existing_ids = np.array(str_ids)[existing_mask]
-        new_ids = np.array(str_ids)[~existing_mask]
-        existing_embeddings = embeddings[existing_mask]
-        new_embeddings = embeddings[~existing_mask]
-
-        # Process updates
-        if len(existing_ids) > 0:
-            int_ids = np.array([self.id_to_int[sid] for sid in existing_ids], dtype=np.int64)
-            self.index.remove_ids(int_ids)
-            self.index.add_with_ids(existing_embeddings, int_ids)
-            for iid, vec in zip(int_ids, existing_embeddings):
-                self.vectors[iid] = vec
-
-        # Process inserts
-        if len(new_ids) > 0:
-            next_int_id = len(self.id_to_int)
-            new_int_ids = np.arange(next_int_id, next_int_id + len(new_ids), dtype=np.int64)
-
-            for sid, iid in zip(new_ids, new_int_ids):
-                self.id_to_int[sid] = iid
-                self.int_to_id[iid] = sid
-                self.vectors[iid] = new_embeddings[np.where(new_ids == sid)[0][0]]
-
-            self.index.add_with_ids(new_embeddings, new_int_ids)
-
-        self._save_to_disk()
-        print(f"Processed {len(data)} items: {len(existing_ids)} updated, {len(new_ids)} created")
+        try:
+            input_ids = {str(item['id']) for item in data}
+            existing_ids = set(self.metadata.keys())
+
+            update_ids = input_ids & existing_ids
+            create_ids = input_ids - existing_ids
+
+            update_items = [item for item in data if str(item['id']) in update_ids]
+            create_items = [item for item in data if str(item['id']) in create_ids]
+
+            # Batch process descriptions and embeddings
+            all_items = update_items + create_items
+            descriptions = [self._format_description(item) for item in all_items]
+            embeddings = model.encode(descriptions).astype('float32')
+
+            # Split embeddings back to update/create
+            update_embeddings = embeddings[:len(update_items)]
+            create_embeddings = embeddings[len(update_items):]
+
+            # Update existing items
+            for i, item in enumerate(update_items):
+                item_id = str(item['id'])
+                self.metadata[item_id].update({
+                    'vector': update_embeddings[i]
+                })
+
+            # Add new items
+            for i, item in enumerate(create_items):
+                item_id = str(item['id'])
+                self.metadata[item_id] = {
+                    'id': item_id,
+                    'vector': create_embeddings[i]
+                }
+
+            # Rebuild index only once
+            all_vectors = [self.metadata[id]['vector'] for id in self.metadata]
+            all_vectors_np = np.array(all_vectors).astype('float32')
+            self.index = faiss.IndexFlatL2(self.dimension)
+            self.index.add(all_vectors_np)
+
+            self._save_to_disk()
+            print(f"Successfully processed {len(data)} items: "
+                  f"{len(update_items)} updated, {len(create_items)} created")
+
+        except Exception as e:
+            print(f"Error in update_embeddings: {e}")
+            raise

     def delete_items(self, item_ids):
         try:
-            str_ids = list(map(str, item_ids))
-            valid_str_ids = [sid for sid in str_ids if sid in self.id_to_int]
-            if not valid_str_ids:
-                print("No valid IDs to delete.")
+            ids_to_delete = {str(id) for id in item_ids}
+            existing_ids = set(self.metadata.keys())
+            valid_ids = ids_to_delete & existing_ids
+
+            if not valid_ids:
+                print("No valid items to delete.")
                 return

-            int_ids = np.array([self.id_to_int[sid] for sid in valid_str_ids], dtype=np.int64)
-            self.index.remove_ids(int_ids)
+            for item_id in valid_ids:
+                del self.metadata[item_id]

-            for sid in valid_str_ids:
-                del self.id_to_int[sid]
-            for iid in int_ids:
-                if iid in self.int_to_id:
-                    del self.int_to_id[iid]
-                if iid in self.vectors:
-                    del self.vectors[iid]
+            if self.metadata:
+                remaining_vectors = [self.metadata[id]['vector'] for id in self.metadata]
+                remaining_vectors_np = np.array(remaining_vectors).astype('float32')
+                self.index = faiss.IndexFlatL2(self.dimension)
+                self.index.add(remaining_vectors_np)
+            else:
+                self.index = faiss.IndexFlatL2(self.dimension)

             self._save_to_disk()
-            print(f"Successfully deleted {len(valid_str_ids)} items")
+            print(f"Successfully deleted {len(valid_ids)} items")
+
         except Exception as e:
             print(f"Error in delete_items: {e}")
             raise
@@ -117,12 +129,8 @@
     def _save_to_disk(self):
         try:
             faiss.write_index(self.index, self.index_path)
-            with open(self.mapping_path, 'wb') as f:
-                pickle.dump({
-                    'id_to_int': self.id_to_int,
-                    'int_to_id': self.int_to_id,
-                    'vectors': self.vectors
-                }, f, protocol=pickle.HIGHEST_PROTOCOL)
+            with open(self.metadata_path, 'wb') as f:
+                pickle.dump(self.metadata, f)
         except Exception as e:
             print(f"Error saving to disk: {e}")
             raise
@@ -130,41 +138,37 @@
     def get_similar_by_ids(self, item_ids: List[str], top_k: int = 5):
         try:
             all_recommendations = []
+            id_list = list(self.metadata.keys())

             for item_id in item_ids:
-                if item_id not in self.id_to_int:
-                    continue
-                int_id = self.id_to_int[item_id]
-                if int_id not in self.vectors:
-                    print(f"Warning: Vector for ID {item_id} not found in cache.")
+                if item_id not in self.metadata:
                     continue

-                query_vector = self.vectors[int_id].reshape(1, -1).astype("float32")
-                distances, indices = self.index.search(
-                    query_vector,
-                    top_k + len(item_ids)
-                )
+                query_vector = self.metadata[item_id]['vector'].reshape(1, -1).astype('float32')
+                distances, indices = self.index.search(query_vector, top_k + len(item_ids))

                 for idx, distance in zip(indices[0], distances[0]):
-                    current_id = self.int_to_id.get(idx)
-                    if current_id and current_id not in item_ids:
+                    if idx < 0 or idx >= len(id_list):
+                        continue
+                    current_id = id_list[idx]
+                    if current_id not in item_ids:
                         all_recommendations.append({
                             'id': current_id,
                             'distance': float(distance)
                         })

-            seen_ids = set()
-            unique_recommendations = []
+            seen = set()
+            recommendations = []
             for rec in sorted(all_recommendations, key=lambda x: x['distance']):
-                if rec['id'] not in seen_ids:
-                    seen_ids.add(rec['id'])
-                    unique_recommendations.append(rec)
-                if len(seen_ids) >= top_k:
+                if rec['id'] not in seen:
+                    seen.add(rec['id'])
+                    recommendations.append(rec)
+                if len(seen) >= top_k:
                     break

             return {
                 "query_ids": item_ids,
-                "recommendations": unique_recommendations[:top_k]
+                "recommendations": recommendations[:top_k]
             }

         except Exception as e:
@@ -173,21 +177,28 @@

     def search_by_query(self, query: str, model, top_k: int):
         try:
-            query_embedding = model.encode(query).astype("float32").reshape(1, -1)
+            query_embedding = model.encode(query).astype('float32').reshape(1, -1)
             actual_top_k = min(top_k, self.index.ntotal) if self.index.ntotal > 0 else 0
+
             if actual_top_k == 0:
                 return []

-            distances, indices = self.index.search(
-                query_embedding,
-                actual_top_k
-            )
-
-            recommendations = [
-                {"id": self.int_to_id.get(idx), "similarity_score": 1 - (dist / 2)}
-                for idx, dist in zip(indices[0], distances[0]) if self.int_to_id.get(idx)
-            ]
-            return recommendations[:top_k]
+            distances, indices = self.index.search(query_embedding, actual_top_k)
+            id_list = list(self.metadata.keys())
+            results = []
+
+            for i in range(actual_top_k):
+                idx = indices[0][i]
+                if idx < 0 or idx >= len(id_list):
+                    continue
+                item_id = id_list[idx]
+                results.append({
+                    "id": item_id,
+                    "similarity_score": 1 - (distances[0][i] / 2)
+                })
+
+            return results

         except Exception as e:
             print(f"Error in search_by_query: {e}")
             raise
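
A hedged end-to-end sketch of the rewritten VectorDB: update_embeddings rebuilds the flat L2 index from the metadata dict in one pass, and search results are mapped back to string IDs by position in that dict. _format_description is part of this file but not shown in the diff, so the item fields below are placeholders for whatever it actually reads.

# Usage sketch with placeholder item fields (not from this commit).
from services.embedding_models.MiniLM_L12_v2_model import ONNXMiniLMModel
from services.vector_db.optimized_vector_db import VectorDB

model = ONNXMiniLMModel()
db = VectorDB(db_path="/tmp/demo_vector_db")   # override the default path for a local test

events = [
    {"id": 1, "name": "Nile dinner cruise"},
    {"id": 2, "name": "Giza pyramids day tour"},
]
db.update_embeddings(events, model)            # encodes, rebuilds IndexFlatL2, saves to disk

print(db.search_by_query("boat trip on the Nile", model, top_k=2))
print(db.get_similar_by_ids(["1"], top_k=1))
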
services/vector_db/similarity_model.py CHANGED
@@ -5,7 +5,7 @@ import pickle
 from typing import List

 class VectorDB:
-    def __init__(self, db_path="./vector_db", dimension=384):
+    def __init__(self, db_path="/.cache/huggingface/hub/my_app_data/vector_db", dimension=384):
         self.db_path = db_path
         self.index_path = os.path.join(db_path, "faiss_index.bin")
         self.metadata_path = os.path.join(db_path, "metadata.pkl")