## Download the arXiv metadata from Kaggle
## https://www.kaggle.com/datasets/Cornell-University/arxiv

## Requires the Kaggle package (which provides the `kaggle` CLI) to be installed
## The CLI is run through subprocess instead of calling the Kaggle Python API,
## because the CLI allows anonymous downloads without needing to sign in
import subprocess
from datasets import load_dataset # To load dataset without breaking ram
from multiprocessing import cpu_count # To get the number of cores
from sentence_transformers import SentenceTransformer # For embedding the text
import torch # For gpu 
import pandas as pd # Data manipulation
from huggingface_hub import snapshot_download # Download previous embeddings
import os # Folder and file creation
from tqdm import tqdm # Progress bar
tqdm.pandas() # Progress bar for pandas
from mixedbread_ai.client import MixedbreadAI # For embedding the text
from dotenv import dotenv_values # To load environment variables
import numpy as np # For array manipulation
from huggingface_hub import HfApi # To transact with huggingface.co
import sys # To quit the script
import datetime # get current year
from time import time, sleep # To time the script
from datetime import datetime # To get the current date and time

# Record the start time so the total runtime can be reported at the end
start = time()

################################################################################
# Configuration

# Current year as a string; compared against the year derived from arXiv ids
year = str(datetime.now().year)

# Flag to force download and conversion even if files already exist
FORCE = True

# Flag to embed the data locally, otherwise it will use mxbai api to embed
LOCAL = False

# Flag to upload the data to the Hugging Face Hub
UPLOAD = True

# Flag to binarise the data
BINARY = True

# Print the configuration, one "Label: value" line per setting
print('Configuration:')
for _label, _value in (
    ('Year', year),
    ('Force', FORCE),
    ('Local', LOCAL),
    ('Upload', UPLOAD),
    ('Binary', BINARY),
):
    print(f'{_label}: {_value}')

########################################

# Model to use for embedding
model_name = "mixedbread-ai/mxbai-embed-large-v1"

# Number of cores to use for multiprocessing (all but one)
num_cores = cpu_count() - 1

# Setup transaction details: dataset repository holding the dense embeddings
repo_id = "bluuebunny/arxiv_abstract_embedding_mxbai_large_v1_milvus"

# Import secrets from the local .env file
config = dotenv_values(".env")

def is_running_in_huggingface_space():
    """Return True when the SPACE_ID environment variable is defined.

    The presence of SPACE_ID (even with an empty value) is used by this
    script as the signal that it runs inside a Hugging Face Space, in which
    case secrets are read from the environment instead of the .env file.
    """
    return os.environ.get("SPACE_ID") is not None

################################################################################
# Download the dataset

# Dataset name
dataset_path = 'Cornell-University/arxiv'

# Download folder
download_folder = 'data'

# Data file path
download_file = f'{download_folder}/arxiv-metadata-oai-snapshot.json'

## Download the dataset if it doesn't exist (or always, when FORCE is set)
if not os.path.exists(download_file) or FORCE:

    print(f'Downloading {download_file}, if it exists it will be overwritten')
    print('Set FORCE to False to skip download if file already exists')

    # check=True makes a failed Kaggle CLI call raise CalledProcessError
    # instead of silently continuing with a missing or stale metadata file
    subprocess.run(
        [
            'kaggle', 'datasets', 'download',
            '--dataset', dataset_path,
            '--path', download_folder,
            '--unzip',
        ],
        check=True,
    )

    print(f'Downloaded {download_file}')

else:

    print(f'{download_file} already exists, skipping download')
    print('Set FORCE = True to force download')

################################################################################
# Load the metadata

# https://huggingface.co/docs/datasets/en/about_arrow#memory-mapping
# Load metadata
print("Loading json metadata")
dataset = load_dataset("json", data_files=download_file)

# Convert to pandas so the metadata can be split by year afterwards
print("Converting metadata into pandas")
arxiv_metadata_all = dataset['train'].to_pandas()

########################################
# Function to extract the month or the year of publication from an arXiv ID
# https://info.arxiv.org/help/arxiv_identifier.html
def extract_month_year(arxiv_id, what='month'):
    """Derive the month name or the year from an arXiv identifier.

    Identifiers containing a slash (e.g. "hep-th/9901001") carry the YYMM
    digits as the first four characters after the last slash; identifiers
    without a slash (e.g. "0704.0001") carry them before the first dot.

    Args:
        arxiv_id: arXiv identifier string.
        what: 'month' returns the full month name; any other value returns
            the four-digit year.

    Returns:
        The month name (e.g. 'April') or the year (e.g. '2007') as a string.

    Raises:
        ValueError: if the extracted prefix cannot be parsed as YYMM.
    """
    if '/' in arxiv_id:
        # Slash form: text after the last slash, first four characters
        yymm = arxiv_id.rpartition('/')[2][:4]
    else:
        # Dotted form: everything before the first dot
        yymm = arxiv_id.partition('.')[0]

    # Two-digit year and two-digit month
    parsed = datetime.strptime(yymm, '%y%m')

    if what == 'month':
        return parsed.strftime('%B')
    return parsed.strftime('%Y')
########################################

# Derive the publication year of every paper from its arXiv id
print("Adding year to metadata")
arxiv_metadata_all['year'] = arxiv_metadata_all['id'].progress_apply(
    extract_month_year,
    what='year',
)

# Keep only the rows whose derived year equals the configured year
print(f"Filtering metadata by year: {year}")
is_selected_year = arxiv_metadata_all['year'] == year
arxiv_metadata_split = arxiv_metadata_all[is_selected_year]

################################################################################
# Load Model

if not LOCAL:
    print("Setting up mxbai API client")
    print("To use local resources, set LOCAL = True")

    # Inside a Hugging Face Space the key is read from the environment,
    # otherwise from the values loaded out of the .env file
    mxbai_api_key = (
        os.getenv("MXBAI_API_KEY")
        if is_running_in_huggingface_space()
        else config["MXBAI_API_KEY"]
    )

    mxbai = MixedbreadAI(api_key=mxbai_api_key)
else:

    print("Setting up local embedding model")
    print("To use mxbai API, set LOCAL = False")

    # Make the app device agnostic: use the GPU when one is available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the pretrained Sentence Transformer model onto the chosen device
    print(f"Loading model {model_name} to device: {device}")
    model = SentenceTransformer(model_name).to(device)

########################################
# Function that does the embedding
def embed(input_text):
    """Embed a piece of text and return the vector as a float32 numpy array.

    When LOCAL is set, the module-level SentenceTransformer `model` is used
    on `device`; otherwise the module-level `mxbai` client is called. Both
    paths use the same `model_name` and both return normalised embeddings,
    so vectors produced by either path can be stored in the same dataset.

    Args:
        input_text: text to embed (the script passes one abstract at a time).

    Returns:
        numpy array with dtype float32 holding the embedding.
    """
    if LOCAL:

        # Calculate embeddings by calling model.encode(), specifying the device.
        # normalize_embeddings=True matches normalized=True of the API branch.
        embedding = model.encode(
            input_text,
            device=device,
            precision="float32",
            normalize_embeddings=True,
        )

        # Enforce 32-bit float precision
        return np.array(embedding, dtype=np.float32)

    # Avoid rate limit from api
    sleep(0.2)

    # Calculate embeddings by calling mxbai.embeddings(); the shared
    # model_name keeps this branch in sync with the local branch
    result = mxbai.embeddings(
        model=model_name,
        input=input_text,
        normalized=True,
        encoding_format='float',
        truncation_strategy='end',
    )

    # Enforce 32-bit float precision
    return np.array(result.data[0].embedding, dtype=np.float32)
########################################

################################################################################
# Gather preexisting embeddings

# Subfolder in the repo of the dataset where the file is stored
folder_in_repo = "data"

# Restrict the download to the parquet file of the configured year
allow_patterns = f"{folder_in_repo}/{year}.parquet"

# The local copy of the dataset lives in a folder named after the repo
local_dir = repo_id

# Set repo type
repo_type = "dataset"

# Create local directory
os.makedirs(local_dir, exist_ok=True)

# Download the repo (only the files matching allow_patterns)
snapshot_download(
    repo_id=repo_id,
    repo_type=repo_type,
    local_dir=local_dir,
    allow_patterns=allow_patterns,
)

# Path where the previously embedded file is expected
previous_embed = f'{local_dir}/{folder_in_repo}/{year}.parquet'

try:
    # Load previous_embed
    print(f"Loading previously embedded file: {previous_embed}")
    previous_embeddings = pd.read_parquet(previous_embed)
except Exception as e:
    # Any failure to read the file is treated as "nothing embedded yet":
    # fall back to an empty frame so every paper of the year gets embedded
    print(f"Errored out with: {e}")
    print(f"No previous embeddings found for year: {year}")
    print("Creating new embeddings for all papers")
    empty_columns = [
        'id', 'vector', 'title', 'abstract', 'authors',
        'categories', 'month', 'year', 'url',
    ]
    previous_embeddings = pd.DataFrame(columns=empty_columns)

########################################
# Embed the new abstracts

# Select the papers whose id is not present in the previous embeddings
already_embedded = arxiv_metadata_split['id'].isin(previous_embeddings['id'])
new_papers = arxiv_metadata_split[~already_embedded]

# Drop duplicates based on the 'id' column, keeping the last occurrence
new_papers = new_papers.drop_duplicates(subset='id', keep='last', ignore_index=True)

# Number of new papers
num_new_papers = len(new_papers)

# Nothing to embed: stop the script here
if not num_new_papers:
    print(f"No new papers found for year: {year}")
    print("Exiting")
    sys.exit()

# Create a column for embeddings, one embed() call per abstract
print(f"Creating new embeddings for: {num_new_papers} entries")
new_papers["vector"] = new_papers["abstract"].progress_apply(embed)

####################
print("Adding url and month columns")

# Add URL column: abstract page address built from the arXiv id
new_papers['url'] = 'https://arxiv.org/abs/' + new_papers['id']

# Add month column: month name derived from the arXiv id
new_papers['month'] = new_papers['id'].progress_apply(extract_month_year, what='month')

####################
print("Removing newline characters from title, authors, categories, abstract")

# Cast each text column to str and replace newline characters with spaces
for _column in ('title', 'authors', 'categories', 'abstract'):
    _as_text = new_papers[_column].astype(str)
    new_papers[_column] = _as_text.str.replace('\n', ' ', regex=False)

####################
print("Trimming title, authors, categories, abstract")

# Maximum number of characters allowed per column
_max_lengths = {
    'title': 512,
    'categories': 128,
    'authors': 128,
    'abstract': 3072,
}

# Values longer than the limit are cut to (limit - 4) characters and get
# '...' appended; shorter values are left unchanged
for _column, _limit in _max_lengths.items():
    new_papers[_column] = new_papers[_column].progress_apply(
        lambda x, _limit=_limit: x[:_limit - 4] + '...' if len(x) > _limit else x
    )

####################
print("Concatenating previouly embedded dataframe with new embeddings")

# Selecting id, vector and $meta to retain
selected_columns = ['id', 'vector', 'title', 'abstract', 'authors', 'categories', 'month', 'year', 'url']

# Append the newly embedded rows below the previously embedded ones
_new_rows = new_papers[selected_columns]
new_embeddings = pd.concat([previous_embeddings, _new_rows])

# Create embed folder
embed_folder = f"{year}-diff-embed"
os.makedirs(embed_folder, exist_ok=True)

# Save the embedded file
embed_filename = f'{embed_folder}/{year}.parquet'
print(f"Saving newly embedded dataframe to: {embed_filename}")

# Keeping index=False to avoid saving the index column as a separate column in the parquet file
# This keeps milvus from throwing an error when importing the parquet file
new_embeddings.to_parquet(embed_filename, index=False)

################################################################################

# Upload the new embeddings to the repo (only when UPLOAD is set)
if not UPLOAD:
    print("Not uploading new embeddings to the repo")
    print("To upload new embeddings, set UPLOAD to True")
else:

    print(f"Uploading new embeddings to: {repo_id}")

    # Setup Hugging Face API: the token comes from the environment inside a
    # Hugging Face Space, otherwise from the values loaded out of .env
    access_token = (
        os.getenv("HF_API_KEY")
        if is_running_in_huggingface_space()
        else config["HF_API_KEY"]
    )

    api = HfApi(token=access_token)

    # Upload all files within the folder to the specified repository
    api.upload_folder(
        repo_id=repo_id,
        folder_path=embed_folder,
        path_in_repo=folder_in_repo,
        repo_type="dataset",
    )

    print(f"Upload complete for year: {year}")
################################################################################

# Binarise the data (only when BINARY is set)
if BINARY:

    print(f"Binarising the data for year: {year}")
    print("Set BINARY = False to not binarise the embeddings")

    def dense_to_binary(dense_vector):
        """Pack a dense vector into bytes, one bit per component.

        A bit is 1 when the component is >= 0 and 0 otherwise.
        """
        non_negative = dense_vector >= 0
        return np.packbits(non_negative).tobytes()

    # Create a folder to store binary embeddings
    binary_folder = f"{year}-binary-embed"
    os.makedirs(binary_folder, exist_ok=True)

    # Replace the dense vectors in place with their packed binary form
    new_embeddings['vector'] = new_embeddings['vector'].progress_apply(dense_to_binary)

    # Save the binary embeddings to a parquet file (no index column)
    _binary_filename = f'{binary_folder}/{year}.parquet'
    new_embeddings.to_parquet(_binary_filename, index=False)

if not (BINARY and UPLOAD):
    print("Not uploading Binary embeddings to the repo")
    print("To upload embeddings, set UPLOAD and BINARY both to True")
else:

    # Setup transaction details: the binary vectors go to their own dataset repo
    repo_id = "bluuebunny/arxiv_abstract_embedding_mxbai_large_v1_milvus_binary"
    repo_type = "dataset"

    # Make sure the target repository exists before uploading
    api.create_repo(repo_id=repo_id, repo_type=repo_type, exist_ok=True)

    # Subfolder in the repo of the dataset where the file is stored
    folder_in_repo = "data"

    print(f"Uploading binary embeddings to {repo_id} from folder {binary_folder}")

    # Upload all files within the folder to the specified repository
    api.upload_folder(
        repo_id=repo_id,
        folder_path=binary_folder,
        path_in_repo=folder_in_repo,
        repo_type=repo_type,
    )

    print("Upload complete")

################################################################################

# Stop the timer started at the top of the script
end = time()

# Calculate and show time taken
_elapsed = end - start
print(f"Time taken: {_elapsed} seconds")

print("Done!")