Spaces:
Running
on
Zero
Running
on
Zero
| import os | |
| import sys | |
| import argparse | |
| from huggingface_hub import hf_hub_download | |
| from pathlib import Path | |
| def download_scenedino_checkpoint(model_name): | |
| print("----------------------- Downloading pretrained model -----------------------") | |
| model_configs = { | |
| "ssc-kitti-360-dino": { | |
| "model-dir": "seg-best-dino" | |
| }, | |
| "ssc-kitti-360-dino-orb-slam": { | |
| "model-dir": "seg-best-dino-orb-slam" | |
| }, | |
| "ssc-kitti-360-dinov2": { | |
| "model-dir": "seg-best-dinov2" | |
| } | |
| } | |
| repo_id = "jev-aleks/SceneDINO" | |
| checkpoint_filename = "checkpoint.pt" | |
| config_filename = "training_config.yaml" | |
| if model_name not in model_configs: | |
| raise ValueError(f"Unknown model: {model_name}. Possible options: {', '.join(model_configs.keys())}") | |
| config = model_configs[model_name] | |
| output_dir = Path("out/scenedino-pretrained") | |
| output_dir.mkdir(parents=True, exist_ok=True) | |
| checkpoint_filename = Path(config["model-dir"]) / checkpoint_filename | |
| config_filename = Path(config["model-dir"]) / config_filename | |
| checkpoint_path = output_dir / checkpoint_filename | |
| config_path = output_dir / config_filename | |
| print(f"Operating in \"{os.getcwd()}\".") | |
| print(f"Creating directories: {output_dir}") | |
| # Download checkpoint | |
| print(f"Downloading checkpoint from HF repo \"{repo_id}\" to \"{checkpoint_path}\".") | |
| hf_hub_download( | |
| repo_id=repo_id, | |
| filename=str(checkpoint_filename), | |
| local_dir=str(output_dir), | |
| ) | |
| # Download config | |
| print(f"Downloading config from HF repo \"{repo_id}\" to \"{config_path}\".") | |
| hf_hub_download( | |
| repo_id=repo_id, | |
| filename=str(config_filename), | |
| local_dir=str(output_dir), | |
| ) | |
| print("Download completed successfully!") | |
| def main(): | |
| parser = argparse.ArgumentParser(description="Download pretrained models from Hugging Face Hub") | |
| parser.add_argument("model", help="Model name to download") | |
| args = parser.parse_args() | |
| download_scenedino_checkpoint(args.model) | |
| if __name__ == "__main__": | |
| main() |