bluuebunny committed · verified
Commit 539851c · 1 Parent(s): a8dedcf

Update update_embeddings.py

Files changed (1)
  1. update_embeddings.py +80 -25
update_embeddings.py CHANGED
@@ -16,11 +16,13 @@ import os # Folder and file creation
 from tqdm import tqdm # Progress bar
 tqdm.pandas() # Progress bar for pandas
 from mixedbread_ai.client import MixedbreadAI # For embedding the text
+from dotenv import dotenv_values # To load environment variables
 import numpy as np # For array manipulation
 from huggingface_hub import HfApi # To transact with huggingface.co
 import sys # To quit the script
 import datetime # get current year
 from time import time, sleep # To time the script
+from datetime import datetime # To get the current date and time
 
 # Start timer
 start = time()
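Side note on the new imports: the file keeps `import datetime` while the added `from datetime import datetime` rebinds that same name to the class, which is what the new `extract_month_year` helper below expects. A minimal sketch of the rebinding behaviour:

```python
import datetime                # `datetime` names the module
from datetime import datetime  # `datetime` now names the class

# The class method used later in this diff resolves fine:
print(datetime.strptime('2301', '%y%m').year)  # 2023

# Module attributes are no longer reachable through this name:
# `datetime.timezone` would now raise AttributeError.
```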
@@ -57,6 +59,12 @@ num_cores = cpu_count()-1
 # Setup transaction details
 repo_id = "bluuebunny/arxiv_abstract_embedding_mxbai_large_v1_milvus"
 
+# Import secrets
+config = dotenv_values(".env")
+
+def is_running_in_huggingface_space():
+    return "SPACE_ID" in os.environ
+
 ################################################################################
 # Download the dataset
 
@@ -90,28 +98,35 @@ else:
 # https://huggingface.co/docs/datasets/en/about_arrow#memory-mapping
 # Load metadata
 print(f"Loading json metadata")
-arxiv_metadata_all = load_dataset("json", data_files= str(f"{download_file}"))
-
-########################################
-# Function to add year to metadata
-def add_year(example):
+dataset = load_dataset("json", data_files= str(f"{download_file}"))
 
-    example['year'] = example['id'].split('/')[1][:2] if '/' in example['id'] else example['id'][:2]
+# Split metadata by year
+# Convert to pandas
+print(f"Converting metadata into pandas")
+arxiv_metadata_all = dataset['train'].to_pandas()
 
-    return example
+########################################
+# Function to extract year from arxiv id
+# https://info.arxiv.org/help/arxiv_identifier.html
+# Function to extract Month and year of publication using arxiv ID
+def extract_month_year(arxiv_id, what='month'):
+    # Identify the relevant YYMM part based on the arXiv ID format
+    yymm = arxiv_id.split('/')[-1][:4] if '/' in arxiv_id else arxiv_id.split('.')[0]
+
+    # Convert the year-month string to a datetime object
+    date = datetime.strptime(yymm, '%y%m')
+
+    # Return the desired part based on the input parameter
+    return date.strftime('%B') if what == 'month' else date.strftime('%Y')
 ########################################
 
 # Add year to metadata
 print(f"Adding year to metadata")
-arxiv_metadata_all = arxiv_metadata_all.map(add_year, num_proc=num_cores)
+arxiv_metadata_all['year'] = arxiv_metadata_all['id'].progress_apply(extract_month_year, what='year')
 
 # Filter by year
 print(f"Filtering metadata by year: {year}")
-arxiv_metadata_all = arxiv_metadata_all.filter(lambda example: example['year'] == year, num_proc=num_cores)
-
-# Convert to pandas
-print(f"Loading metadata for year: {year} into pandas")
-arxiv_metadata_split = arxiv_metadata_all['train'].to_pandas()
+arxiv_metadata_split = arxiv_metadata_all[arxiv_metadata_all['year'] == year]
 
 ################################################################################
 # Load Model
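arXiv identifiers encode the submission date in a YYMM prefix: old-style IDs such as `hep-th/0501001` carry it after the slash, new-style IDs such as `2301.00001` before the dot (see https://info.arxiv.org/help/arxiv_identifier.html). A quick standalone check of the parsing logic, copied from the hunk above:

```python
from datetime import datetime

def extract_month_year(arxiv_id, what='month'):
    # YYMM sits after the slash (old style) or before the dot (new style)
    yymm = arxiv_id.split('/')[-1][:4] if '/' in arxiv_id else arxiv_id.split('.')[0]
    date = datetime.strptime(yymm, '%y%m')
    return date.strftime('%B') if what == 'month' else date.strftime('%Y')

print(extract_month_year('hep-th/0501001', what='year'))  # 2005
print(extract_month_year('2301.00001', what='month'))     # January
```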
@@ -131,8 +146,13 @@ if LOCAL:
 else:
     print("Setting up mxbai API client")
     print("To use local resources, set LOCAL = True")
+
     # Setup mxbai
-    mxbai_api_key = os.getenv("MXBAI_API_KEY")
+    if is_running_in_huggingface_space():
+        mxbai_api_key = os.getenv("MXBAI_API_KEY")
+    else:
+        mxbai_api_key = config["MXBAI_API_KEY"]
+
     mxbai = MixedbreadAI(api_key=mxbai_api_key)
 
     ########################################
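The key lookup above leans on the `SPACE_ID` variable that Hugging Face Spaces inject into the environment. A self-contained sketch of the same resolution pattern, with a hypothetical `get_secret` helper that is not part of the commit:

```python
import os
from dotenv import dotenv_values  # pip install python-dotenv

def is_running_in_huggingface_space() -> bool:
    # Spaces expose SPACE_ID as an environment variable
    return "SPACE_ID" in os.environ

def get_secret(name: str) -> str:
    # Space secrets arrive as environment variables; locally, read .env
    if is_running_in_huggingface_space():
        return os.getenv(name)
    return dotenv_values(".env")[name]

# e.g. mxbai_api_key = get_secret("MXBAI_API_KEY")
```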
@@ -142,10 +162,13 @@ def embed(input_text):
     if LOCAL:
 
         # Calculate embeddings by calling model.encode(), specifying the device
-        embedding = model.encode(input_text, device=device)
+        embedding = model.encode(input_text, device=device, precision="float32")
 
-    else:
+        # Enforce 32-bit float precision
+        embedding = np.array(embedding, dtype=np.float32)
 
+    else:
+
         # Avoid rate limit from api
         sleep(0.2)
 
@@ -158,7 +181,8 @@ def embed(input_text):
             truncation_strategy='end'
         )
 
-        embedding = np.array(result.data[0].embedding)
+        # Enforce 32-bit float precision
+        embedding = np.array(result.data[0].embedding, dtype=np.float32)
 
     return embedding
 ########################################
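With these two hunks, both the local and the API branch of `embed` return `np.float32` vectors, so the parquet column keeps a single dtype regardless of which path produced it. A small illustration with stand-in values:

```python
import numpy as np

api_like = [0.1, 0.2, 0.3]  # hypothetical embedding, Python floats (64-bit)
embedding = np.array(api_like, dtype=np.float32)

assert embedding.dtype == np.float32  # half the storage of float64
```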
@@ -203,6 +227,9 @@ except Exception as e:
 # Find papers that are not in the previous embeddings
 new_papers = arxiv_metadata_split[~arxiv_metadata_split['id'].isin(previous_embeddings['id'])]
 
+# Drop duplicates based on the 'id' column
+new_papers = new_papers.drop_duplicates(subset='id', keep='last', ignore_index=True)
+
 # Number of new papers
 num_new_papers = len(new_papers)
 
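`keep='last'` retains the most recent metadata row per arXiv ID when the dump contains several. A toy illustration:

```python
import pandas as pd

df = pd.DataFrame({'id': ['2301.00001', '2301.00001', '2301.00002'],
                   'abstract': ['old version', 'new version', 'unchanged']})
df = df.drop_duplicates(subset='id', keep='last', ignore_index=True)
print(df['abstract'].tolist())  # ['new version', 'unchanged']
```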
@@ -216,17 +243,39 @@ if num_new_papers == 0:
 print(f"Creating new embeddings for: {num_new_papers} entries")
 new_papers["vector"] = new_papers["abstract"].progress_apply(embed)
 
-# Rename columns
-new_papers.rename(columns={'title': 'Title', 'authors': 'Authors', 'abstract': 'Abstract'}, inplace=True)
-
+####################
 # Add URL column
-new_papers['URL'] = 'https://arxiv.org/abs/' + new_papers['id']
+arxiv_metadata_split['url'] = 'https://arxiv.org/abs/' + arxiv_metadata_split['id']
+
+# Add month column
+arxiv_metadata_split['month'] = arxiv_metadata_split['id'].progress_apply(extract_month_year, what='month')
+
+####################
+# Remove newline characters from authors, title, abstract and categories columns
+arxiv_metadata_split['title'] = arxiv_metadata_split['title'].astype(str).str.replace('\n', ' ', regex=False)
+
+arxiv_metadata_split['authors'] = arxiv_metadata_split['authors'].astype(str).str.replace('\n', ' ', regex=False)
+
+arxiv_metadata_split['categories'] = arxiv_metadata_split['categories'].astype(str).str.replace('\n', ' ', regex=False)
 
-# Create milvus compatible parquet file, $meta is a json string of the metadata
-new_papers['$meta'] = new_papers[['Title', 'Authors', 'Abstract', 'URL']].apply(lambda row: json.dumps(row.to_dict()), axis=1)
+arxiv_metadata_split['abstract'] = arxiv_metadata_split['abstract'].astype(str).str.replace('\n', ' ', regex=False)
 
+####################
+# Trim title to 512 characters
+arxiv_metadata_split['title'] = arxiv_metadata_split['title'].progress_apply(lambda x: x[:508] + '...' if len(x) > 512 else x)
+
+# Trim categories to 128 characters
+arxiv_metadata_split['categories'] = arxiv_metadata_split['categories'].progress_apply(lambda x: x[:124] + '...' if len(x) > 128 else x)
+
+# Trim authors to 128 characters
+arxiv_metadata_split['authors'] = arxiv_metadata_split['authors'].progress_apply(lambda x: x[:124] + '...' if len(x) > 128 else x)
+
+# Trim abstract to 3072 characters
+arxiv_metadata_split['abstract'] = arxiv_metadata_split['abstract'].progress_apply(lambda x: x[:3068] + '...' if len(x) > 3072 else x)
+
+####################
 # Selecting id, vector and $meta to retain
-selected_columns = ['id', 'vector', '$meta']
+selected_columns = ['id', 'vector', 'title', 'abstract', 'authors', 'categories', 'month', 'year', 'url']
 
 # Merge previous embeddings and new embeddings
 new_embeddings = pd.concat([previous_embeddings, new_papers[selected_columns]])
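The four trim lines all follow one pattern: cap a string column at N characters, replacing the tail with `'...'` so the result still fits the column limit. The same logic as a reusable helper (illustrative only, not in the commit):

```python
def trim(text: str, limit: int) -> str:
    # Keep limit-4 characters plus a 3-character ellipsis marker,
    # mirroring the x[:508] + '...' / len(x) > 512 convention above
    return text[:limit - 4] + '...' if len(text) > limit else text

# e.g. df['title'] = df['title'].progress_apply(trim, limit=512)
```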
@@ -248,7 +297,13 @@ new_embeddings.to_parquet(embed_filename, index=False)
 if UPLOAD:
 
     print(f"Uploading new embeddings to: {repo_id}")
-    access_token = os.getenv("HF_API_KEY")
+
+    # Setup Hugging Face API
+    if is_running_in_huggingface_space():
+        access_token = os.getenv("HF_API_KEY")
+    else:
+        access_token = config["HF_API_KEY"]
+
     api = HfApi(token=access_token)
 
     # Upload all files within the folder to the specified repository
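The diff ends just before the actual upload call. For context, a hedged sketch of how such an upload typically looks with `huggingface_hub` (the folder path and repo type are assumptions, not taken from the commit):

```python
from huggingface_hub import HfApi

api = HfApi(token=access_token)  # token resolved as in the hunk above

# upload_folder pushes every file in a local directory to the repo
api.upload_folder(
    folder_path="embeddings",  # hypothetical local folder
    repo_id="bluuebunny/arxiv_abstract_embedding_mxbai_large_v1_milvus",
    repo_type="dataset",       # assumed; could also be a model repo
)
```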
 
16
  from tqdm import tqdm # Progress bar
17
  tqdm.pandas() # Progress bar for pandas
18
  from mixedbread_ai.client import MixedbreadAI # For embedding the text
19
+ from dotenv import dotenv_values # To load environment variables
20
  import numpy as np # For array manipulation
21
  from huggingface_hub import HfApi # To transact with huggingface.co
22
  import sys # To quit the script
23
  import datetime # get current year
24
  from time import time, sleep # To time the script
25
+ from datetime import datetime # To get the current date and time
26
 
27
  # Start timer
28
  start = time()
 
59
  # Setup transaction details
60
  repo_id = "bluuebunny/arxiv_abstract_embedding_mxbai_large_v1_milvus"
61
 
62
+ # Import secrets
63
+ config = dotenv_values(".env")
64
+
65
+ def is_running_in_huggingface_space():
66
+ return "SPACE_ID" in os.environ
67
+
68
  ################################################################################
69
  # Download the dataset
70
 
 
98
  # https://huggingface.co/docs/datasets/en/about_arrow#memory-mapping
99
  # Load metadata
100
  print(f"Loading json metadata")
101
+ dataset = load_dataset("json", data_files= str(f"{download_file}"))
 
 
 
 
102
 
103
+ # Split metadata by year
104
+ # Convert to pandas
105
+ print(f"Converting metadata into pandas")
106
+ arxiv_metadata_all = dataset['train'].to_pandas()
107
 
108
+ ########################################
109
+ # Function to extract year from arxiv id
110
+ # https://info.arxiv.org/help/arxiv_identifier.html
111
+ # Function to extract Month and year of publication using arxiv ID
112
+ def extract_month_year(arxiv_id, what='month'):
113
+ # Identify the relevant YYMM part based on the arXiv ID format
114
+ yymm = arxiv_id.split('/')[-1][:4] if '/' in arxiv_id else arxiv_id.split('.')[0]
115
+
116
+ # Convert the year-month string to a datetime object
117
+ date = datetime.strptime(yymm, '%y%m')
118
+
119
+ # Return the desired part based on the input parameter
120
+ return date.strftime('%B') if what == 'month' else date.strftime('%Y')
121
  ########################################
122
 
123
  # Add year to metadata
124
  print(f"Adding year to metadata")
125
+ arxiv_metadata_all['year'] = arxiv_metadata_all['id'].progress_apply(extract_month_year, what='year')
126
 
127
  # Filter by year
128
  print(f"Filtering metadata by year: {year}")
129
+ arxiv_metadata_split = arxiv_metadata_all[arxiv_metadata_all['year'] == year]
 
 
 
 
130
 
131
  ################################################################################
132
  # Load Model
 
146
  else:
147
  print("Setting up mxbai API client")
148
  print("To use local resources, set LOCAL = True")
149
+
150
  # Setup mxbai
151
+ if is_running_in_huggingface_space():
152
+ mxbai_api_key = os.getenv("MXBAI_API_KEY")
153
+ else:
154
+ mxbai_api_key = config["MXBAI_API_KEY"]
155
+
156
  mxbai = MixedbreadAI(api_key=mxbai_api_key)
157
 
158
  ########################################
 
162
  if LOCAL:
163
 
164
  # Calculate embeddings by calling model.encode(), specifying the device
165
+ embedding = model.encode(input_text, device=device, precision="float32")
166
 
167
+ # Enforce 32-bit float precision
168
+ embedding = np.array(embedding, dtype=np.float32)
169
 
170
+ else:
171
+
172
  # Avoid rate limit from api
173
  sleep(0.2)
174
 
 
181
  truncation_strategy='end'
182
  )
183
 
184
+ # Enforce 32-bit float precision
185
+ embedding = np.array(result.data[0].embedding, dtype=np.float32)
186
 
187
  return embedding
188
  ########################################
 
227
  # Find papers that are not in the previous embeddings
228
  new_papers = arxiv_metadata_split[~arxiv_metadata_split['id'].isin(previous_embeddings['id'])]
229
 
230
+ # Drop duplicates based on the 'id' column
231
+ new_papers = new_papers.drop_duplicates(subset='id', keep='last', ignore_index=True)
232
+
233
  # Number of new papers
234
  num_new_papers = len(new_papers)
235
 
 
243
  print(f"Creating new embeddings for: {num_new_papers} entries")
244
  new_papers["vector"] = new_papers["abstract"].progress_apply(embed)
245
 
246
+ ####################
 
 
247
  # Add URL column
248
+ arxiv_metadata_split['url'] = 'https://arxiv.org/abs/' + arxiv_metadata_split['id']
249
+
250
+ # Add month column
251
+ arxiv_metadata_split['month'] = arxiv_metadata_split['id'].progress_apply(extract_month_year, what='month')
252
+
253
+ ####################
254
+ # Remove newline characters from authors, title, abstract and categories columns
255
+ arxiv_metadata_split['title'] = arxiv_metadata_split['title'].astype(str).str.replace('\n', ' ', regex=False)
256
+
257
+ arxiv_metadata_split['authors'] = arxiv_metadata_split['authors'].astype(str).str.replace('\n', ' ', regex=False)
258
+
259
+ arxiv_metadata_split['categories'] = arxiv_metadata_split['categories'].astype(str).str.replace('\n', ' ', regex=False)
260
 
261
+ arxiv_metadata_split['abstract'] = arxiv_metadata_split['abstract'].astype(str).str.replace('\n', ' ', regex=False)
 
262
 
263
+ ####################
264
+ # Trim title to 512 characters
265
+ arxiv_metadata_split['title'] = arxiv_metadata_split['title'].progress_apply(lambda x: x[:508] + '...' if len(x) > 512 else x)
266
+
267
+ # Trim categories to 128 characters
268
+ arxiv_metadata_split['categories'] = arxiv_metadata_split['categories'].progress_apply(lambda x: x[:124] + '...' if len(x) > 128 else x)
269
+
270
+ # Trim authors to 128 characters
271
+ arxiv_metadata_split['authors'] = arxiv_metadata_split['authors'].progress_apply(lambda x: x[:124] + '...' if len(x) > 128 else x)
272
+
273
+ # Trim abstract to 3072 characters
274
+ arxiv_metadata_split['abstract'] = arxiv_metadata_split['abstract'].progress_apply(lambda x: x[:3068] + '...' if len(x) > 3072 else x)
275
+
276
+ ####################
277
  # Selecting id, vector and $meta to retain
278
+ selected_columns = ['id', 'vector', 'title', 'abstract', 'authors', 'categories', 'month', 'year', 'url']
279
 
280
  # Merge previous embeddings and new embeddings
281
  new_embeddings = pd.concat([previous_embeddings, new_papers[selected_columns]])
 
297
  if UPLOAD:
298
 
299
  print(f"Uploading new embeddings to: {repo_id}")
300
+
301
+ # Setup Hugging Face API
302
+ if is_running_in_huggingface_space():
303
+ access_token = os.getenv("HF_API_KEY")
304
+ else:
305
+ access_token = config["HF_API_KEY"]
306
+
307
  api = HfApi(token=access_token)
308
 
309
  # Upload all files within the folder to the specified repository