bstraehle committed on
Commit aa247a8 · verified · 1 Parent(s): 1435df9

Update custom_utils.py

Files changed (1)
  1. custom_utils.py +33 -54
custom_utils.py CHANGED
@@ -29,20 +29,16 @@ def rag_ingestion(collection):
     return "Manually create a vector search index (in free tier, this feature is not available via SDK)"
 
 def rag_retrieval(openai_api_key, prompt, db, collection, stages=[], vector_index="vector_index"):
-    # Assuming vector_search returns a list of dictionaries with keys 'title' and 'plot'
     get_knowledge = vector_search(openai_api_key, prompt, db, collection, stages, vector_index)
 
-    # Check if there are any results
     if not get_knowledge:
         return "No results found.", "No source information available."
 
-    # Convert search results into a list of SearchResultItem models
     search_results_models = [
         SearchResultItem(**result)
         for result in get_knowledge
     ]
 
-    # Convert search results into a DataFrame for better rendering in Jupyter
     search_results_df = pd.DataFrame([item.dict() for item in search_results_models])
 
     print("###")
@@ -54,7 +50,6 @@ def rag_retrieval(openai_api_key, prompt, db, collection, stages=[], vector_index="vector_index"):
 def rag_inference(openai_api_key, prompt, search_results):
     openai.api_key = openai_api_key
 
-    # Generate system response using OpenAI's completion
     content = f"Answer this user question: {prompt} with the following context:\n{search_results}"
 
     completion = openai.chat.completions.create(
@@ -70,13 +65,7 @@ def rag_inference(openai_api_key, prompt, search_results):
         ]
     )
 
-    completion_result = completion.choices[0].message.content
-
-    print("###")
-    print(completion_result)
-    print("###")
-
-    return completion_result
+    return completion.choices[0].message.content
 
 def process_records(data_frame):
     records = data_frame.to_dict(orient="records")
@@ -101,33 +90,19 @@ def process_records(data_frame):
     return []
 
 def vector_search(openai_api_key, user_query, db, collection, additional_stages=[], vector_index="vector_index"):
-    """
-    Perform a vector search in the MongoDB collection based on the user query.
-
-    Args:
-        user_query (str): The user's query string.
-        db (MongoClient.database): The database object.
-        collection (MongoCollection): The MongoDB collection to search.
-        additional_stages (list): Additional aggregation stages to include in the pipeline.
-
-    Returns:
-        list: A list of matching documents.
-    """
-
-    # Generate embedding for the user query
-    query_embedding = get_embedding(openai_api_key, user_query)
+    query_embedding = get_text_embedding(openai_api_key, user_query)
+    query_embedding2 = get_image_embedding(openai_api_key, user_query)
 
     if query_embedding is None:
         return "Invalid query or embedding generation failed."
 
-    # Define the vector search stage
     vector_search_stage = {
         "$vectorSearch": {
-            "index": vector_index, # specifies the index to use for the search
-            "queryVector": query_embedding, # the vector representing the query
-            "path": "text_embeddings", # field in the documents containing the vectors to search against
-            "numCandidates": 150, # number of candidate matches to consider
-            "limit": 20, # return top 20 matches
+            "index": vector_index,
+            "queryVector": query_embedding,
+            "path": "text_embeddings",
+            "numCandidates": 150,
+            "limit": 20,
             "filter": {
                 "$and": [
                     {"accommodates": {"$eq": 2}},
@@ -137,36 +112,25 @@ def vector_search(openai_api_key, user_query, db, collection, additional_stages=[], vector_index="vector_index"):
         }
     }
 
-    # Define the aggregate pipeline with the vector search stage and additional stages
     pipeline = [vector_search_stage] + additional_stages
 
-    # Execute the search
     results = collection.aggregate(pipeline)
 
-    explain_query_execution = db.command( # sends a database command directly to the MongoDB server
-        'explain', { # return information about how MongoDB executes a query or command without actually running it
-            'aggregate': collection.name, # specifies the name of the collection on which the aggregation is performed
-            'pipeline': pipeline, # the aggregation pipeline to analyze
-            'cursor': {} # indicates that default cursor behavior should be used
+    explain_query_execution = db.command(
+        "explain", {
+            "aggregate": collection.name,
+            "pipeline": pipeline,
+            "cursor": {}
         },
-        verbosity='executionStats') # detailed statistics about the execution of each stage of the aggregation pipeline
+        verbosity='executionStats')
 
-    vector_search_explain = explain_query_execution['stages'][0]['$vectorSearch']
-
-    millis_elapsed = vector_search_explain['explain']['collectStats']['millisElapsed']
-    print("###")
-    print(vector_search_explain)
-    print("###")
-    print(vector_search_explain['explain'])
-    print("###")
-    print(f"Total time for the execution to complete on the database server: {millis_elapsed} milliseconds")
+    vector_search_explain = explain_query_execution["stages"][0]["$vectorSearch"]
+    millis_elapsed = vector_search_explain["explain"]["collectStats"]["millisElapsed"]
+    print(f"Query execution time: {millis_elapsed} milliseconds")
 
     return list(results)
 
-def get_embedding(openai_api_key, text):
-    """Generate an embedding for the given text using OpenAI's API."""
-
-    # Check for valid input
+def get_text_embedding(openai_api_key, text):
     if not text or not isinstance(text, str):
         return None
 
@@ -177,6 +141,21 @@ def get_embedding(openai_api_key, text):
             input=text,
             model="text-embedding-3-small", dimensions=1536).data[0].embedding
         return embedding
+    except Exception as e:
+        print(f"Error in get_embedding: {e}")
+        return None
+
+def get_image_embedding(openai_api_key, text):
+    if not text or not isinstance(text, str):
+        return None
+
+    openai.api_key = openai_api_key
+
+    try:
+        embedding = openai.embeddings.create(
+            input=text,
+            model="openai/clip-vit-base-patch32", dimensions=512).data[0].embedding
+        return embedding
     except Exception as e:
         print(f"Error in get_embedding: {e}")
         return None
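
For reference, a minimal usage sketch of the helpers touched by this commit (not part of the commit itself). It assumes a MongoDB Atlas collection whose documents carry "text_embeddings" vectors, a manually created search index named "vector_index", and a valid OpenAI API key; the connection string, database name, collection name, and prompt below are placeholders, and the success-path return value of rag_retrieval (not shown in the hunks above) is assumed to be usable directly as the context for rag_inference.

import os
from pymongo import MongoClient
from custom_utils import rag_retrieval, rag_inference  # helpers updated in this commit

# Placeholder connection details (assumptions, not taken from the commit).
client = MongoClient(os.environ["MONGODB_URI"])
db = client["sample_airbnb"]      # hypothetical database name
collection = db["listings"]       # hypothetical collection name
openai_api_key = os.environ["OPENAI_API_KEY"]

prompt = "Recommend a quiet apartment for two guests."

# Retrieval: runs the $vectorSearch aggregation (with the hard-coded accommodates filter
# shown above) and prints the query execution time added in this commit.
search_results = rag_retrieval(openai_api_key, prompt, db, collection,
                               stages=[], vector_index="vector_index")

# Inference: folds the retrieved context into the chat completion prompt.
print(rag_inference(openai_api_key, prompt, search_results))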