## Download the arXiv metadata from Kaggle
## https://www.kaggle.com/datasets/Cornell-University/arxiv
## Requires the Kaggle API to be installed
## Using subprocess to run the Kaggle CLI commands instead of Kaggle API
## As it allows for anonymous downloads without needing to sign in
import subprocess
from datasets import load_dataset # To load dataset without breaking ram
from multiprocessing import cpu_count # To get the number of cores
from sentence_transformers import SentenceTransformer # For embedding the text
import torch # For gpu
import pandas as pd # Data manipulation
from huggingface_hub import snapshot_download # Download previous embeddings
import os # Folder and file creation
from tqdm import tqdm # Progress bar
tqdm.pandas() # Progress bar for pandas
from mixedbread_ai.client import MixedbreadAI # For embedding the text
from dotenv import dotenv_values # To load environment variables
import numpy as np # For array manipulation
from huggingface_hub import HfApi # To transact with huggingface.co
import sys # To quit the script
from time import time, sleep # To time the script
from datetime import datetime # To get the current date, time and year
# Start timer
start = time()
################################################################################
# Configuration
# Get current year
year = str(datetime.now().year)
# Flag to force download and conversion even if files already exist
FORCE = True
# Flag to embed the data locally, otherwise it will use mxbai api to embed
LOCAL = False
# Flag to upload the data to the Hugging Face Hub
UPLOAD = True
# Flag to binarise the data
BINARY = True
# Print the configuration
print(f'Configuration:')
print(f'Year: {year}')
print(f'Force: {FORCE}')
print(f'Local: {LOCAL}')
print(f'Upload: {UPLOAD}')
print(f'Binary: {BINARY}')
########################################
# Model to use for embedding
model_name = "mixedbread-ai/mxbai-embed-large-v1"
# Number of cores to use for multiprocessing
num_cores = cpu_count()-1
# Setup transaction details
repo_id = "bluuebunny/arxiv_abstract_embedding_mxbai_large_v1_milvus"
# Import secrets
config = dotenv_values(".env")
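# Hugging Face Spaces set a SPACE_ID environment variable; its presence is used below to
# decide whether secrets come from the environment (Space secrets) or from the .env file.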
def is_running_in_huggingface_space():
    return "SPACE_ID" in os.environ
################################################################################
# Download the dataset
# Dataset name
dataset_path = 'Cornell-University/arxiv'
# Download folder
download_folder = 'data'
# Data file path
download_file = f'{download_folder}/arxiv-metadata-oai-snapshot.json'
## Download the dataset if it doesn't exist
if not os.path.exists(download_file) or FORCE:
    print(f'Downloading {download_file}, if it exists it will be overwritten')
    print('Set FORCE to False to skip download if file already exists')
    subprocess.run(['kaggle', 'datasets', 'download', '--dataset', dataset_path, '--path', download_folder, '--unzip'])
    print(f'Downloaded {download_file}')
else:
    print(f'{download_file} already exists, skipping download')
    print('Set FORCE = True to force download')
################################################################################
# Filter by year and convert to parquet
# https://huggingface.co/docs/datasets/en/about_arrow#memory-mapping
# Load metadata
print(f"Loading json metadata")
dataset = load_dataset("json", data_files=download_file)
# Split metadata by year
# Convert to pandas
print(f"Converting metadata into pandas")
arxiv_metadata_all = dataset['train'].to_pandas()
########################################
# Function to extract the month or year of publication from an arXiv ID
# https://info.arxiv.org/help/arxiv_identifier.html
def extract_month_year(arxiv_id, what='month'):
    # Identify the relevant YYMM part based on the arXiv ID format
    yymm = arxiv_id.split('/')[-1][:4] if '/' in arxiv_id else arxiv_id.split('.')[0]
    # Convert the year-month string to a datetime object
    date = datetime.strptime(yymm, '%y%m')
    # Return the desired part based on the input parameter
    return date.strftime('%B') if what == 'month' else date.strftime('%Y')
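# For reference (per the identifier scheme linked above): a new-style ID such as
# '0704.0001' yields yymm '0704' (April 2007), while an old-style ID such as
# 'hep-th/9901001' yields '9901' (January 1999).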
########################################
# Add year to metadata
print(f"Adding year to metadata")
arxiv_metadata_all['year'] = arxiv_metadata_all['id'].progress_apply(extract_month_year, what='year')
# Filter by year
print(f"Filtering metadata by year: {year}")
arxiv_metadata_split = arxiv_metadata_all[arxiv_metadata_all['year'] == year]
################################################################################
# Load Model
if LOCAL:
    print("Setting up local embedding model")
    print("To use mxbai API, set LOCAL = False")
    # Make the app device agnostic
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    # Load a pretrained Sentence Transformer model and move it to the appropriate device
    print(f"Loading model {model_name} to device: {device}")
    model = SentenceTransformer(model_name)
    model = model.to(device)
else:
    print("Setting up mxbai API client")
    print("To use local resources, set LOCAL = True")
    # Setup mxbai
    if is_running_in_huggingface_space():
        mxbai_api_key = os.getenv("MXBAI_API_KEY")
    else:
        mxbai_api_key = config["MXBAI_API_KEY"]
    mxbai = MixedbreadAI(api_key=mxbai_api_key)
########################################
# Function that does the embedding
def embed(input_text):
    if LOCAL:
        # Calculate embeddings by calling model.encode(), specifying the device
        embedding = model.encode(input_text, device=device, precision="float32")
        # Enforce 32-bit float precision
        embedding = np.array(embedding, dtype=np.float32)
    else:
        # Avoid rate limit from api
        sleep(0.2)
        # Calculate embeddings by calling mxbai.embeddings()
        result = mxbai.embeddings(
            model='mixedbread-ai/mxbai-embed-large-v1',
            input=input_text,
            normalized=True,
            encoding_format='float',
            truncation_strategy='end'
        )
        # Enforce 32-bit float precision
        embedding = np.array(result.data[0].embedding, dtype=np.float32)
    return embedding
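# Either path is expected to return a 1D float32 numpy array; for mxbai-embed-large-v1
# that should be 1024 dimensions, e.g. embed("some abstract").shape == (1024,).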
########################################
################################################################################
# Gather preexisting embeddings
# Subfolder in the repo of the dataset where the file is stored
folder_in_repo = "data"
allow_patterns = f"{folder_in_repo}/{year}.parquet"
# Where to store the local copy of the dataset
local_dir = repo_id
# Set repo type
repo_type = "dataset"
# Create local directory
os.makedirs(local_dir, exist_ok=True)
# Download the repo
snapshot_download(repo_id=repo_id, repo_type=repo_type, local_dir=local_dir, allow_patterns=allow_patterns)
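# Thanks to allow_patterns, only data/{year}.parquet is fetched; if the repo has no file
# for this year yet, nothing is downloaded and the try/except below falls back to an
# empty dataframe.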
try:
    # Gather previous embed file
    previous_embed = f'{local_dir}/{folder_in_repo}/{year}.parquet'
    # Load previous_embed
    print(f"Loading previously embedded file: {previous_embed}")
    previous_embeddings = pd.read_parquet(previous_embed)
except Exception as e:
    print(f"Errored out with: {e}")
    print(f"No previous embeddings found for year: {year}")
    print("Creating new embeddings for all papers")
    previous_embeddings = pd.DataFrame(columns=['id', 'vector', 'title', 'abstract', 'authors', 'categories', 'month', 'year', 'url'])
########################################
# Embed the new abstracts
# Find papers that are not in the previous embeddings
new_papers = arxiv_metadata_split[~arxiv_metadata_split['id'].isin(previous_embeddings['id'])]
# Drop duplicates based on the 'id' column
new_papers = new_papers.drop_duplicates(subset='id', keep='last', ignore_index=True)
# Number of new papers
num_new_papers = len(new_papers)
# What if there are no new papers?
if num_new_papers == 0:
    print(f"No new papers found for year: {year}")
    print("Exiting")
    sys.exit()
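# Note: in API mode (LOCAL = False) the 0.2 s sleep per abstract alone adds roughly
# 33 minutes per 10,000 new papers, before any network latency.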
# Create a column for embeddings
print(f"Creating new embeddings for: {num_new_papers} entries")
new_papers["vector"] = new_papers["abstract"].progress_apply(embed)
####################
print("Adding url and month columns")
# Add URL column
new_papers['url'] = 'https://arxiv.org/abs/' + new_papers['id']
# Add month column
new_papers['month'] = new_papers['id'].progress_apply(extract_month_year, what='month')
####################
print("Removing newline characters from title, authors, categories, abstract")
# Remove newline characters from authors, title, abstract and categories columns
new_papers['title'] = new_papers['title'].astype(str).str.replace('\n', ' ', regex=False)
new_papers['authors'] = new_papers['authors'].astype(str).str.replace('\n', ' ', regex=False)
new_papers['categories'] = new_papers['categories'].astype(str).str.replace('\n', ' ', regex=False)
new_papers['abstract'] = new_papers['abstract'].astype(str).str.replace('\n', ' ', regex=False)
####################
print("Trimming title, authors, categories, abstract")
# Trim title to 512 characters
new_papers['title'] = new_papers['title'].progress_apply(lambda x: x[:508] + '...' if len(x) > 512 else x)
# Trim categories to 128 characters
new_papers['categories'] = new_papers['categories'].progress_apply(lambda x: x[:124] + '...' if len(x) > 128 else x)
# Trim authors to 128 characters
new_papers['authors'] = new_papers['authors'].progress_apply(lambda x: x[:124] + '...' if len(x) > 128 else x)
# Trim abstract to 3072 characters
new_papers['abstract'] = new_papers['abstract'].progress_apply(lambda x: x[:3068] + '...' if len(x) > 3072 else x)
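# The cut-offs (508 / 124 / 3068 characters plus a '...' suffix) keep each field within the
# 512 / 128 / 3072 limits named above; presumably these mirror the VARCHAR max_length values
# of the target Milvus collection schema (an assumption, not stated in this script).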
####################
print("Concatenating previouly embedded dataframe with new embeddings")
# Selecting id, vector and $meta to retain
selected_columns = ['id', 'vector', 'title', 'abstract', 'authors', 'categories', 'month', 'year', 'url']
# Merge previous embeddings and new embeddings
new_embeddings = pd.concat([previous_embeddings, new_papers[selected_columns]])
# Create embed folder
embed_folder = f"{year}-diff-embed"
os.makedirs(embed_folder, exist_ok=True)
# Save the embedded file
embed_filename = f'{embed_folder}/{year}.parquet'
print(f"Saving newly embedded dataframe to: {embed_filename}")
# Keeping index=False to avoid saving the index column as a separate column in the parquet file
# This keeps milvus from throwing an error when importing the parquet file
new_embeddings.to_parquet(embed_filename, index=False)
################################################################################
# Upload the new embeddings to the repo
if UPLOAD:
    print(f"Uploading new embeddings to: {repo_id}")
    # Setup Hugging Face API
    if is_running_in_huggingface_space():
        access_token = os.getenv("HF_API_KEY")
    else:
        access_token = config["HF_API_KEY"]
    api = HfApi(token=access_token)
    # Upload all files within the folder to the specified repository
    api.upload_folder(repo_id=repo_id, folder_path=embed_folder, path_in_repo=folder_in_repo, repo_type="dataset")
    print(f"Upload complete for year: {year}")
else:
    print("Not uploading new embeddings to the repo")
    print("To upload new embeddings, set UPLOAD to True")
################################################################################
# Binarise the data
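# When BINARY is set, each dense vector is reduced to its sign bits: np.packbits turns a
# 1024-dimensional float32 vector (~4 KB) into a 128-byte packed binary vector (1024 / 8),
# suitable for a binary-vector field in the downstream Milvus collection.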
if BINARY:
    print(f"Binarising the data for year: {year}")
    print("Set BINARY = False to not binarise the embeddings")
    # Function to convert dense vector to binary vector
    def dense_to_binary(dense_vector):
        return np.packbits(np.where(dense_vector >= 0, 1, 0)).tobytes()
    # Create a folder to store binary embeddings
    binary_folder = f"{year}-binary-embed"
    os.makedirs(binary_folder, exist_ok=True)
    # Convert the dense vectors to binary vectors
    new_embeddings['vector'] = new_embeddings['vector'].progress_apply(dense_to_binary)
    # Save the binary embeddings to a parquet file
    new_embeddings.to_parquet(f'{binary_folder}/{year}.parquet', index=False)
if BINARY and UPLOAD:
    # Setup transaction details
    repo_id = "bluuebunny/arxiv_abstract_embedding_mxbai_large_v1_milvus_binary"
    repo_type = "dataset"
    api.create_repo(repo_id=repo_id, repo_type=repo_type, exist_ok=True)
    # Subfolder in the repo of the dataset where the file is stored
    folder_in_repo = "data"
    print(f"Uploading binary embeddings to {repo_id} from folder {binary_folder}")
    # Upload all files within the folder to the specified repository
    api.upload_folder(repo_id=repo_id, folder_path=binary_folder, path_in_repo=folder_in_repo, repo_type=repo_type)
    print("Upload complete")
else:
    print("Not uploading Binary embeddings to the repo")
    print("To upload embeddings, set UPLOAD and BINARY both to True")
################################################################################
# Track time
end = time()
# Calculate and show time taken
print(f"Time taken: {end - start} seconds")
print("Done!") |