# NOTE: web-scrape residue ("Spaces: Sleeping") — not part of the original source
# Copyright (C) 2021-2024, Mindee.
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.
# Inspired by: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/hub.py
import json
import logging
import os
import subprocess
import textwrap
from pathlib import Path
from typing import Any

from huggingface_hub import (
    HfApi,
    Repository,
    get_token,
    get_token_permission,
    hf_hub_download,
    login,
    snapshot_download,
)

from doctr import models
from doctr.file_utils import is_tf_available, is_torch_available

if is_torch_available():
    import torch
__all__ = ["login_to_hub", "push_to_hf_hub", "from_hub", "_save_model_and_config_for_hf_hub"]

# Registry of exportable architecture names per task; used by `push_to_hf_hub`
# to validate the requested `arch` before creating a hub repository.
AVAILABLE_ARCHS = {
    "classification": models.classification.zoo.ARCHS,
    "detection": models.detection.zoo.ARCHS,
    "recognition": models.recognition.zoo.ARCHS,
}
def login_to_hub() -> None:  # pragma: no cover
    """Login to huggingface hub

    Reuses a cached token when one is found and valid, otherwise falls back to
    an interactive login prompt. Also checks that git-lfs is installed, since
    pushing model weights to the hub goes through git LFS.
    """
    access_token = get_token()
    if access_token is not None and get_token_permission(access_token):
        logging.info("Huggingface Hub token found and valid")
        # re-login with the cached token, requesting write access for pushes
        login(token=access_token, write_permission=True)
    else:
        login()
    # check if git lfs is installed
    try:
        subprocess.call(["git", "lfs", "version"])
    except FileNotFoundError:
        raise OSError(
            "Looks like you do not have git-lfs installed, please install. \
            You can install from https://git-lfs.github.com/. \
            Then run `git lfs install` (you only have to do this once)."
        )
def _save_model_and_config_for_hf_hub(model: Any, save_dir: str, arch: str, task: str) -> None:
    """Save model weights and config to disk for pushing to huggingface hub

    Writes the weights in the framework-native format (``pytorch_model.bin`` or
    a TF ``tf_model/weights`` checkpoint) and a ``config.json`` carrying the
    model configuration augmented with the hub-specific ``arch``/``task`` keys.

    Args:
    ----
        model: TF or PyTorch model to be saved
        save_dir: directory to save model and config
        arch: architecture name
        task: task name
    """
    save_directory = Path(save_dir)

    # Serialize weights with whichever framework backend is active
    if is_torch_available():
        weights_path = save_directory / "pytorch_model.bin"
        torch.save(model.state_dict(), weights_path)
    elif is_tf_available():
        weights_path = save_directory / "tf_model" / "weights"
        model.save_weights(str(weights_path))

    config_path = save_directory / "config.json"

    # Work on a shallow copy: the original code aliased `model.cfg` and mutated
    # it, permanently polluting the model's own config with hub-only keys.
    model_config = dict(model.cfg)
    model_config["arch"] = arch
    model_config["task"] = task

    with config_path.open("w") as f:
        json.dump(model_config, f, indent=2, ensure_ascii=False)
def push_to_hf_hub(model: Any, model_name: str, task: str, **kwargs) -> None:  # pragma: no cover
    """Save model and its configuration on HF hub

    Creates the repository, clones it into the local HF cache, writes the
    weights/config and a generated README, then pushes in a single commit.

    >>> from doctr.models import login_to_hub, push_to_hf_hub
    >>> from doctr.models.recognition import crnn_mobilenet_v3_small
    >>> login_to_hub()
    >>> model = crnn_mobilenet_v3_small(pretrained=True)
    >>> push_to_hf_hub(model, 'my-model', 'recognition', arch='crnn_mobilenet_v3_small')

    Args:
    ----
        model: TF or PyTorch model to be saved
        model_name: name of the model which is also the repository name
        task: task name
        **kwargs: keyword arguments for push_to_hf_hub
    """
    # Either an explicit `arch` or a training `run_config` (which carries the
    # arch) must be supplied
    run_config = kwargs.get("run_config", None)
    arch = kwargs.get("arch", None)

    if run_config is None and arch is None:
        raise ValueError("run_config or arch must be specified")
    if task not in ["classification", "detection", "recognition"]:
        raise ValueError("task must be one of classification, detection, recognition")

    # default readme
    readme = textwrap.dedent(
        f"""
    ---
    language: en
    ---

    <p align="center">
    <img src="https://doctr-static.mindee.com/models?id=v0.3.1/Logo_doctr.gif&src=0" width="60%">
    </p>

    **Optical Character Recognition made seamless & accessible to anyone, powered by TensorFlow 2 & PyTorch**

    ## Task: {task}

    https://github.com/mindee/doctr

    ### Example usage:

    ```python
    >>> from doctr.io import DocumentFile
    >>> from doctr.models import ocr_predictor, from_hub

    >>> img = DocumentFile.from_images(['<image_path>'])
    >>> # Load your model from the hub
    >>> model = from_hub('mindee/my-model')

    >>> # Pass it to the predictor
    >>> # If your model is a recognition model:
    >>> predictor = ocr_predictor(det_arch='db_mobilenet_v3_large',
    >>>                           reco_arch=model,
    >>>                           pretrained=True)

    >>> # If your model is a detection model:
    >>> predictor = ocr_predictor(det_arch=model,
    >>>                           reco_arch='crnn_mobilenet_v3_small',
    >>>                           pretrained=True)

    >>> # Get your predictions
    >>> res = predictor(img)
    ```
    """
    )

    # add run configuration to readme if available
    if run_config is not None:
        arch = run_config.arch
        readme += textwrap.dedent(
            f"""### Run Configuration
            \n{json.dumps(vars(run_config), indent=2, ensure_ascii=False)}"""
        )

    # `arch` must be one of the exportable architectures for this task
    if arch not in AVAILABLE_ARCHS[task]:
        raise ValueError(
            f"Architecture: {arch} for task: {task} not found.\
            \nAvailable architectures: {AVAILABLE_ARCHS}"
        )

    commit_message = f"Add {model_name} model"

    # Clone into the user's HF cache so repeated pushes reuse the checkout
    local_cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "hub", model_name)
    repo_url = HfApi().create_repo(model_name, token=get_token(), exist_ok=False)
    repo = Repository(local_dir=local_cache_dir, clone_from=repo_url, use_auth_token=True)

    # `repo.commit` stages everything written inside the context into one commit
    with repo.commit(commit_message):
        _save_model_and_config_for_hf_hub(model, repo.local_dir, arch=arch, task=task)
        readme_path = Path(repo.local_dir) / "README.md"
        readme_path.write_text(readme)

    repo.git_push()
def from_hub(repo_id: str, **kwargs: Any):
    """Instantiate & load a pretrained model from HF hub.

    >>> from doctr.models import from_hub
    >>> model = from_hub("mindee/fasterrcnn_mobilenet_v3_large_fpn")

    Args:
    ----
        repo_id: HuggingFace model hub repo
        kwargs: kwargs of `hf_hub_download` or `snapshot_download`

    Returns:
    -------
        Model loaded with the checkpoint

    Raises:
    ------
        ValueError: if the config declares a task other than
            classification/detection/recognition
    """
    # Get the config
    with open(hf_hub_download(repo_id, filename="config.json", **kwargs), "rb") as f:
        cfg = json.load(f)

    # `arch`/`task` are hub-specific keys added at export time; strip them so
    # `cfg` only carries the model's own configuration
    arch = cfg.pop("arch")
    task = cfg.pop("task")

    if task == "classification":
        model = models.classification.__dict__[arch](
            pretrained=False, classes=cfg["classes"], num_classes=cfg["num_classes"]
        )
    elif task == "detection":
        model = models.detection.__dict__[arch](pretrained=False)
    elif task == "recognition":
        model = models.recognition.__dict__[arch](pretrained=False, input_shape=cfg["input_shape"], vocab=cfg["vocab"])
    else:
        # Previously an unknown task fell through every branch and crashed
        # below with a confusing `NameError: model`; fail loudly instead
        raise ValueError(f"Unknown task: {task}. Must be one of classification, detection, recognition")

    # update model cfg
    model.cfg = cfg

    # Load checkpoint
    if is_torch_available():
        state_dict = torch.load(hf_hub_download(repo_id, filename="pytorch_model.bin", **kwargs), map_location="cpu")
        model.load_state_dict(state_dict)
    else:  # tf
        repo_path = snapshot_download(repo_id, **kwargs)
        model.load_weights(os.path.join(repo_path, "tf_model", "weights"))
    return model