import os
import argparse
from huggingface_hub import login, snapshot_download
from typing import Literal
from sentence_transformers import SentenceTransformer
from pydantic import BaseModel

DATA_PATH = os.getenv('DATA_PATH')
HUGGINGFACE_ACCESS_TOKEN = os.getenv('HF_TOKEN')

login(token=HUGGINGFACE_ACCESS_TOKEN)

class ModelConfig(BaseModel):
    model_id : str
    mode: Literal['snapshot', 'model'] = 'model'

    class Config:
        protected_namespaces = ()

def download(config: ModelConfig):
    try:

        if config.mode == 'snapshot':
            snapshot_download(
                config.model_id,
                revision='main',
                ignore_patterns=['*.git*', '*README.md'],
                local_dir=os.path.join(DATA_PATH, config.model_id)
            )
        else:
            model = SentenceTransformer(
                config.model_id,
                trust_remote_code=True,
            )
            model.save(os.path.join(DATA_PATH, config.model_id))
    except Exception as e:
        raise e
    
def run():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '-i', '--input',
        help='model id to download',
        required=True,
    )
    parser.add_argument(
        '-m', '--mode',
        help='mode to download',
        default='model',
    )

    args = parser.parse_args()
    config = ModelConfig(
        model_id=args.input,
        mode=args.mode
    )


    download(config)

if __name__ == '__main__':
    run()