## Download the arXiv metadata from Kaggle
## https://www.kaggle.com/datasets/Cornell-University/arxiv

## Requires the Kaggle CLI (the `kaggle` package) to be installed
## Uses subprocess to run Kaggle CLI commands instead of the Kaggle Python API,
## as the CLI allows anonymous downloads without needing to sign in
import subprocess
from datasets import load_dataset # To load the dataset without exhausting RAM
from multiprocessing import cpu_count # To get the number of cores
from sentence_transformers import SentenceTransformer # For embedding the text
import torch # For GPU support when embedding locally
import pandas as pd # Data manipulation
from huggingface_hub import snapshot_download # Download previous embeddings
import os # Folder and file creation
from tqdm import tqdm # Progress bar
tqdm.pandas() # Progress bar for pandas
from mixedbread_ai.client import MixedbreadAI # For embedding the text
from dotenv import dotenv_values # To load environment variables
import numpy as np # For array manipulation
from huggingface_hub import HfApi # To transact with huggingface.co
import sys # To quit the script
from time import time, sleep # To time the script and throttle API calls
from datetime import datetime # To get the current date, time and year

# Start timer
start = time()

################################################################################
# Configuration

# Get current year
year = str(datetime.now().year)

# Flag to force download and conversion even if files already exist
FORCE = True

# Flag to embed the data locally; otherwise the mxbai API is used for embedding
LOCAL = False

# Flag to upload the data to the Hugging Face Hub
UPLOAD = True

# Flag to binarise the data
BINARY = True

# Print the configuration
print(f'Configuration:')
print(f'Year: {year}')
print(f'Force: {FORCE}')
print(f'Local: {LOCAL}')
print(f'Upload: {UPLOAD}')
print(f'Binary: {BINARY}')

########################################

# Model to use for embedding
model_name = "mixedbread-ai/mxbai-embed-large-v1"

# Number of cores available for multiprocessing (currently unused in this script)
num_cores = cpu_count() - 1

# Setup transaction details
repo_id = "bluuebunny/arxiv_abstract_embedding_mxbai_large_v1_milvus"

# Import secrets
config = dotenv_values(".env")

def is_running_in_huggingface_space():
    return "SPACE_ID" in os.environ

################################################################################
# Download the dataset

# Dataset name
dataset_path = 'Cornell-University/arxiv'

# Download folder
download_folder = 'data'

# Data file path
download_file = f'{download_folder}/arxiv-metadata-oai-snapshot.json'

## Download the dataset if it doesn't exist
if not os.path.exists(download_file) or FORCE:

    print(f'Downloading {download_file}, if it exists it will be overwritten')
    print('Set FORCE to False to skip download if file already exists')

    subprocess.run(['kaggle', 'datasets', 'download', '--dataset', dataset_path, '--path', download_folder, '--unzip'])
    
    print(f'Downloaded {download_file}')

else:

    print(f'{download_file} already exists, skipping download')
    print('Set FORCE = True to force download')

################################################################################
# Filter by year and convert to parquet

# https://huggingface.co/docs/datasets/en/about_arrow#memory-mapping
# Load metadata
print(f"Loading json metadata")
dataset = load_dataset("json", data_files= str(f"{download_file}"))

# Split metadata by year
# Convert to pandas
print(f"Converting metadata into pandas")
arxiv_metadata_all = dataset['train'].to_pandas()

########################################
# Extract the month or year of publication from an arXiv ID
# https://info.arxiv.org/help/arxiv_identifier.html
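# e.g. a new-style ID such as '2301.01234' yields YYMM '2301' (January 2023),
# while an old-style ID such as 'hep-th/9901001' yields '9901' (January 1999)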
def extract_month_year(arxiv_id, what='month'):
    # Identify the relevant YYMM part based on the arXiv ID format
    yymm = arxiv_id.split('/')[-1][:4] if '/' in arxiv_id else arxiv_id.split('.')[0]
    
    # Convert the year-month string to a datetime object
    date = datetime.strptime(yymm, '%y%m')
    
    # Return the desired part based on the input parameter
    return date.strftime('%B') if what == 'month' else date.strftime('%Y')
########################################

# Add year to metadata
print(f"Adding year to metadata")
arxiv_metadata_all['year'] = arxiv_metadata_all['id'].progress_apply(extract_month_year, what='year')

# Filter by year
print(f"Filtering metadata by year: {year}")
arxiv_metadata_split = arxiv_metadata_all[arxiv_metadata_all['year'] == year]

################################################################################
# Load Model

if LOCAL:

    print(f"Setting up local embedding model")
    print("To use mxbai API, set LOCAL = False")

    # Make the app device agnostic
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # Load a pretrained Sentence Transformer model and move it to the appropriate device
    print(f"Loading model {model_name} to device: {device}")
    model = SentenceTransformer(model_name)
    model = model.to(device)
else:
    print("Setting up mxbai API client")
    print("To use local resources, set LOCAL = True")

    # Setup mxbai
    if is_running_in_huggingface_space():
        mxbai_api_key = os.getenv("MXBAI_API_KEY")
    else:
        mxbai_api_key = config["MXBAI_API_KEY"]

    mxbai = MixedbreadAI(api_key=mxbai_api_key)

########################################
# Function that does the embedding
def embed(input_text):
    
    if LOCAL:

        # Calculate embeddings by calling model.encode(), specifying the device
        embedding = model.encode(input_text, device=device, precision="float32")

        # Enforce 32-bit float precision
        embedding = np.array(embedding, dtype=np.float32)

    else:
        
        # Avoid rate limit from api
        sleep(0.2)
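        # 0.2 s between calls throttles the loop to roughly 5 requests per second;
        # the exact API rate limit is an assumption, adjust if your quota differs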

        # Calculate embeddings by calling mxbai.embeddings()
        result = mxbai.embeddings(
            model='mixedbread-ai/mxbai-embed-large-v1',
            input=input_text,
            normalized=True,
            encoding_format='float',
            truncation_strategy='end'
        )

        # Enforce 32-bit float precision
        embedding = np.array(result.data[0].embedding, dtype=np.float32)

    return embedding
########################################

################################################################################
# Gather preexisting embeddings

# Subfolder in the repo of the dataset where the file is stored
folder_in_repo = "data"
allow_patterns = f"{folder_in_repo}/{year}.parquet"
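# e.g. for year '2024' only data/2024.parquet is fetched from the repo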

# Where to store the local copy of the dataset
local_dir = repo_id

# Set repo type
repo_type = "dataset"

# Create local directory
os.makedirs(local_dir, exist_ok=True)

# Download the repo
snapshot_download(repo_id=repo_id, repo_type=repo_type, local_dir=local_dir, allow_patterns=allow_patterns)

try:

    # Gather previous embed file
    previous_embed = f'{local_dir}/{folder_in_repo}/{year}.parquet'

    # Load previous_embed
    print(f"Loading previously embedded file: {previous_embed}")   
    previous_embeddings = pd.read_parquet(previous_embed)

except Exception as e:
    print(f"Errored out with: {e}")
    print(f"No previous embeddings found for year: {year}")
    print("Creating new embeddings for all papers")
    previous_embeddings = pd.DataFrame(columns=['id', 'vector', 'title', 'abstract', 'authors', 'categories', 'month', 'year', 'url'])
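    # The column layout mirrors selected_columns used further below, so this empty
    # frame concatenates cleanly with the newly embedded papers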

########################################
# Embed the new abstracts

# Find papers that are not in the previous embeddings
new_papers = arxiv_metadata_split[~arxiv_metadata_split['id'].isin(previous_embeddings['id'])]
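# Papers already present keep their stored vectors, so repeat runs for the same year
# only pay the embedding cost for newly added papers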

# Drop duplicates based on the 'id' column
new_papers = new_papers.drop_duplicates(subset='id', keep='last', ignore_index=True)

# Number of new papers
num_new_papers = len(new_papers)

# What if there are no new papers?
if num_new_papers == 0:
    print(f"No new papers found for year: {year}")
    print("Exiting")
    sys.exit()

# Create a column for embeddings
print(f"Creating new embeddings for: {num_new_papers} entries")
new_papers["vector"] = new_papers["abstract"].progress_apply(embed)

####################
print("Adding url and month columns")

# Add URL column
new_papers['url'] = 'https://arxiv.org/abs/' + new_papers['id']

# Add month column
new_papers['month'] = new_papers['id'].progress_apply(extract_month_year, what='month')

####################
print("Removing newline characters from title, authors, categories, abstract")

# Remove newline characters from authors, title, abstract and categories columns
new_papers['title'] = new_papers['title'].astype(str).str.replace('\n', ' ', regex=False)

new_papers['authors'] = new_papers['authors'].astype(str).str.replace('\n', ' ', regex=False)

new_papers['categories'] = new_papers['categories'].astype(str).str.replace('\n', ' ', regex=False)

new_papers['abstract'] = new_papers['abstract'].astype(str).str.replace('\n', ' ', regex=False)

####################
print("Trimming title, authors, categories, abstract")

# Trim title to 512 characters
new_papers['title'] = new_papers['title'].progress_apply(lambda x: x[:508] + '...' if len(x) > 512 else x)

# Trim categories to 128 characters
new_papers['categories'] = new_papers['categories'].progress_apply(lambda x: x[:124] + '...' if len(x) > 128 else x)

# Trim authors to 128 characters
new_papers['authors'] = new_papers['authors'].progress_apply(lambda x: x[:124] + '...' if len(x) > 128 else x)

# Trim abstract to 3072 characters
new_papers['abstract'] = new_papers['abstract'].progress_apply(lambda x: x[:3068] + '...' if len(x) > 3072 else x)

####################
print("Concatenating previouly embedded dataframe with new embeddings")

# Select the id, vector and metadata ($meta) columns to retain
selected_columns = ['id', 'vector', 'title', 'abstract', 'authors', 'categories', 'month', 'year', 'url']

# Merge previous embeddings and new embeddings
new_embeddings = pd.concat([previous_embeddings, new_papers[selected_columns]])

# Create embed folder
embed_folder = f"{year}-diff-embed"
os.makedirs(embed_folder, exist_ok=True)

# Save the embedded file
embed_filename = f'{embed_folder}/{year}.parquet'
print(f"Saving newly embedded dataframe to: {embed_filename}")
# Keeping index=False to avoid saving the index column as a separate column in the parquet file
# This keeps Milvus from throwing an error when importing the parquet file
new_embeddings.to_parquet(embed_filename, index=False)

################################################################################

# Upload the new embeddings to the repo
if UPLOAD:

    print(f"Uploading new embeddings to: {repo_id}")

    # Setup Hugging Face API
    if is_running_in_huggingface_space():
        access_token = os.getenv("HF_API_KEY")
    else:
        access_token = config["HF_API_KEY"]

    api = HfApi(token=access_token)

    # Upload all files within the folder to the specified repository
    api.upload_folder(repo_id=repo_id, folder_path=embed_folder, path_in_repo=folder_in_repo, repo_type="dataset")

    print(f"Upload complete for year: {year}")

else:
    print("Not uploading new embeddings to the repo")
    print("To upload new embeddings, set UPLOAD to True")
################################################################################

# Binarise the data
if BINARY:

    print(f"Binarising the data for year: {year}")
    print("Set BINARY = False to not binarise the embeddings")

    # Function to convert dense vector to binary vector
    def dense_to_binary(dense_vector):
        return np.packbits(np.where(dense_vector >= 0, 1, 0)).tobytes()
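    # Each dimension becomes one bit (>= 0 -> 1, negative -> 0) and np.packbits packs
    # 8 bits per byte, so a 1024-dimensional mxbai-embed-large-v1 vector (assuming the
    # model's default output size) shrinks to 128 bytes, the raw byte format stored in
    # a Milvus BINARY_VECTOR field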

    # Create a folder to store binary embeddings
    binary_folder = f"{year}-binary-embed"
    os.makedirs(binary_folder, exist_ok=True)

    # Convert the dense vectors to binary vectors
    new_embeddings['vector'] = new_embeddings['vector'].progress_apply(dense_to_binary)

    # Save the binary embeddings to a parquet file
    new_embeddings.to_parquet(f'{binary_folder}/{year}.parquet', index=False)

if BINARY and UPLOAD:

    # Setup transaction details
    repo_id = "bluuebunny/arxiv_abstract_embedding_mxbai_large_v1_milvus_binary"
    repo_type = "dataset"

    api.create_repo(repo_id=repo_id, repo_type=repo_type, exist_ok=True)
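    # exist_ok=True makes repo creation idempotent, so repeated runs do not fail here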

    # Subfolder in the repo of the dataset where the file is stored
    folder_in_repo = "data"

    print(f"Uploading binary embeddings to {repo_id} from folder {binary_folder}")

    # Upload all files within the folder to the specified repository
    api.upload_folder(repo_id=repo_id, folder_path=binary_folder, path_in_repo=folder_in_repo, repo_type=repo_type)

    print("Upload complete")

else:
    print("Not uploading Binary embeddings to the repo")
    print("To upload embeddings, set UPLOAD and BINARY both to True")

################################################################################

# Track time
end = time()

# Calculate and show time taken
print(f"Time taken: {end - start} seconds")

print("Done!")