Update update_embeddings.py
update_embeddings.py  (+80, −25)  CHANGED
@@ -16,11 +16,13 @@ import os # Folder and file creation
 from tqdm import tqdm # Progress bar
 tqdm.pandas() # Progress bar for pandas
 from mixedbread_ai.client import MixedbreadAI # For embedding the text
+from dotenv import dotenv_values # To load environment variables
 import numpy as np # For array manipulation
 from huggingface_hub import HfApi # To transact with huggingface.co
 import sys # To quit the script
 import datetime # get current year
 from time import time, sleep # To time the script
+from datetime import datetime # To get the current date and time

 # Start timer
 start = time()
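Worth noting when reading this hunk: the file now contains both import datetime and from datetime import datetime, and the second binding wins, so later code has to call the class directly (the new extract_month_year helper below does exactly that with datetime.strptime). A minimal illustration of the shadowing:

import datetime
from datetime import datetime

datetime.strptime('2401', '%y%m')   # fine: the name datetime now refers to the class
# datetime.datetime.now()           # would raise AttributeError after the second import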
@@ -57,6 +59,12 @@ num_cores = cpu_count()-1
 # Setup transaction details
 repo_id = "bluuebunny/arxiv_abstract_embedding_mxbai_large_v1_milvus"

+# Import secrets
+config = dotenv_values(".env")
+
+def is_running_in_huggingface_space():
+    return "SPACE_ID" in os.environ
+
 ################################################################################
 # Download the dataset

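For context, a minimal sketch of the .env file that dotenv_values(".env") is expected to find when the script runs outside a Space; the key names come from later hunks, the values are placeholders:

# .env (kept next to the script, out of version control)
MXBAI_API_KEY=your-mixedbread-api-key
HF_API_KEY=your-huggingface-write-token

dotenv_values returns a plain dict and does not touch os.environ, while Hugging Face Spaces expose their secrets as environment variables and set SPACE_ID, which is what is_running_in_huggingface_space checks to decide between the two sources.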
@@ -90,28 +98,35 @@ else:
 # https://huggingface.co/docs/datasets/en/about_arrow#memory-mapping
 # Load metadata
 print(f"Loading json metadata")
-########################################
-# Function to add year to metadata
-def add_year(example):
+dataset = load_dataset("json", data_files= str(f"{download_file}"))

+# Split metadata by year
+# Convert to pandas
+print(f"Converting metadata into pandas")
+arxiv_metadata_all = dataset['train'].to_pandas()

+########################################
+# Function to extract year from arxiv id
+# https://info.arxiv.org/help/arxiv_identifier.html
+# Function to extract Month and year of publication using arxiv ID
+def extract_month_year(arxiv_id, what='month'):
+    # Identify the relevant YYMM part based on the arXiv ID format
+    yymm = arxiv_id.split('/')[-1][:4] if '/' in arxiv_id else arxiv_id.split('.')[0]
+
+    # Convert the year-month string to a datetime object
+    date = datetime.strptime(yymm, '%y%m')
+
+    # Return the desired part based on the input parameter
+    return date.strftime('%B') if what == 'month' else date.strftime('%Y')
 ########################################

 # Add year to metadata
 print(f"Adding year to metadata")
-arxiv_metadata_all =
+arxiv_metadata_all['year'] = arxiv_metadata_all['id'].progress_apply(extract_month_year, what='year')

 # Filter by year
 print(f"Filtering metadata by year: {year}")
-# Convert to pandas
-print(f"Loading metadata for year: {year} into pandas")
-arxiv_metadata_split = arxiv_metadata_all['train'].to_pandas()
+arxiv_metadata_split = arxiv_metadata_all[arxiv_metadata_all['year'] == year]

 ################################################################################
 # Load Model
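A quick sanity check of extract_month_year against both arXiv identifier schemes (illustrative IDs, assuming the function above is in scope):

print(extract_month_year('2106.01345', what='year'))      # '2021'   (post-2007 IDs: YYMM.NNNNN)
print(extract_month_year('2106.01345', what='month'))     # 'June'
print(extract_month_year('hep-th/9901001', what='year'))  # '1999'   (pre-2007 IDs: archive/YYMMNNN)
print(extract_month_year('hep-th/9901001', what='month')) # 'January'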
@@ -131,8 +146,13 @@ if LOCAL:
 else:
     print("Setting up mxbai API client")
     print("To use local resources, set LOCAL = True")
+
     # Setup mxbai
+    if is_running_in_huggingface_space():
+        mxbai_api_key = os.getenv("MXBAI_API_KEY")
+    else:
+        mxbai_api_key = config["MXBAI_API_KEY"]
+
     mxbai = MixedbreadAI(api_key=mxbai_api_key)

 ########################################
@@ -142,10 +162,13 @@ def embed(input_text):
     if LOCAL:

         # Calculate embeddings by calling model.encode(), specifying the device
-        embedding = model.encode(input_text, device=device)
+        embedding = model.encode(input_text, device=device, precision="float32")

+        # Enforce 32-bit float precision
+        embedding = np.array(embedding, dtype=np.float32)

+    else:
+
         # Avoid rate limit from api
         sleep(0.2)

@@ -158,7 +181,8 @@ def embed(input_text):
             truncation_strategy='end'
         )

+        # Enforce 32-bit float precision
+        embedding = np.array(result.data[0].embedding, dtype=np.float32)

     return embedding
 ########################################
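The two new casts do the same job in both branches: the repository name suggests the parquet file feeds a Milvus FLOAT_VECTOR field, which stores 32-bit floats, while np.array over a plain Python list of floats (the shape the API result arrives in) defaults to float64. A minimal sketch of the difference, assuming nothing beyond NumPy:

import numpy as np

api_like = [0.1, 0.2, 0.3]                          # embeddings as plain Python floats
print(np.array(api_like).dtype)                     # float64 by default
print(np.array(api_like, dtype=np.float32).dtype)   # float32, matching precision="float32" locally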
@@ -203,6 +227,9 @@ except Exception as e:
 # Find papers that are not in the previous embeddings
 new_papers = arxiv_metadata_split[~arxiv_metadata_split['id'].isin(previous_embeddings['id'])]

+# Drop duplicates based on the 'id' column
+new_papers = new_papers.drop_duplicates(subset='id', keep='last', ignore_index=True)
+
 # Number of new papers
 num_new_papers = len(new_papers)

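A toy illustration of what keep='last' does when the metadata dump carries the same arXiv id more than once (hypothetical rows):

import pandas as pd

df = pd.DataFrame({'id': ['2106.01345', '2106.01345', '2203.00001'],
                   'abstract': ['old record', 'latest record', 'unique record']})
print(df.drop_duplicates(subset='id', keep='last', ignore_index=True))
#            id       abstract
# 0  2106.01345  latest record
# 1  2203.00001  unique record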
@@ -216,17 +243,39 @@ if num_new_papers == 0:
 print(f"Creating new embeddings for: {num_new_papers} entries")
 new_papers["vector"] = new_papers["abstract"].progress_apply(embed)

-new_papers.rename(columns={'title': 'Title', 'authors': 'Authors', 'abstract': 'Abstract'}, inplace=True)
+####################
 # Add URL column
-new_papers['$meta'] = new_papers[['Title', 'Authors', 'Abstract', 'URL']].apply(lambda row: json.dumps(row.to_dict()), axis=1)
+arxiv_metadata_split['url'] = 'https://arxiv.org/abs/' + arxiv_metadata_split['id']
+
+# Add month column
+arxiv_metadata_split['month'] = arxiv_metadata_split['id'].progress_apply(extract_month_year, what='month')
+
+####################
+# Remove newline characters from authors, title, abstract and categories columns
+arxiv_metadata_split['title'] = arxiv_metadata_split['title'].astype(str).str.replace('\n', ' ', regex=False)
+
+arxiv_metadata_split['authors'] = arxiv_metadata_split['authors'].astype(str).str.replace('\n', ' ', regex=False)
+
+arxiv_metadata_split['categories'] = arxiv_metadata_split['categories'].astype(str).str.replace('\n', ' ', regex=False)
+
+arxiv_metadata_split['abstract'] = arxiv_metadata_split['abstract'].astype(str).str.replace('\n', ' ', regex=False)
+
+####################
+# Trim title to 512 characters
+arxiv_metadata_split['title'] = arxiv_metadata_split['title'].progress_apply(lambda x: x[:508] + '...' if len(x) > 512 else x)
+
+# Trim categories to 128 characters
+arxiv_metadata_split['categories'] = arxiv_metadata_split['categories'].progress_apply(lambda x: x[:124] + '...' if len(x) > 128 else x)
+
+# Trim authors to 128 characters
+arxiv_metadata_split['authors'] = arxiv_metadata_split['authors'].progress_apply(lambda x: x[:124] + '...' if len(x) > 128 else x)
+
+# Trim abstract to 3072 characters
+arxiv_metadata_split['abstract'] = arxiv_metadata_split['abstract'].progress_apply(lambda x: x[:3068] + '...' if len(x) > 3072 else x)
+
+####################
 # Selecting id, vector and $meta to retain
-selected_columns = ['id', 'vector', '$meta']
+selected_columns = ['id', 'vector', 'title', 'abstract', 'authors', 'categories', 'month', 'year', 'url']

 # Merge previous embeddings and new embeddings
 new_embeddings = pd.concat([previous_embeddings, new_papers[selected_columns]])
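The four trim lines all follow one pattern; written as a hypothetical helper (not part of the diff), with the 512/128/3072 caps presumably matching the VARCHAR max_length limits of the target Milvus collection schema, which is not shown here:

def trim(text, limit):
    # Keep at most `limit` characters, reserving three for the ellipsis marker
    return text[:limit - 4] + '...' if len(text) > limit else text

print(len(trim('a' * 600, 512)))   # 511: 508 kept characters plus '...'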
@@ -248,7 +297,13 @@ new_embeddings.to_parquet(embed_filename, index=False)
 if UPLOAD:

     print(f"Uploading new embeddings to: {repo_id}")
+
+    # Setup Hugging Face API
+    if is_running_in_huggingface_space():
+        access_token = os.getenv("HF_API_KEY")
+    else:
+        access_token = config["HF_API_KEY"]
+
     api = HfApi(token=access_token)

     # Upload all files within the folder to the specified repository
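The upload call itself sits below this hunk, outside the changed lines; for context, a typical huggingface_hub pattern for pushing a folder looks like the sketch below. folder_path and repo_type here are assumptions for illustration, not values taken from the diff:

api.upload_folder(
    folder_path="embeddings",   # assumed local folder holding the parquet file
    repo_id=repo_id,
    repo_type="dataset",        # assumed; the target repo hosts data rather than a model
)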