# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch

from segment_anything import build_sam, build_sam_vit_b, build_sam_vit_l
from segment_anything.utils.onnx import SamOnnxModel

import argparse
import warnings
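
# onnxruntime is optional: it is only needed to sanity-check the exported model and to
# run the optional quantization step.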
try:
    import onnxruntime  # type: ignore

    onnxruntime_exists = True
except ImportError:
    onnxruntime_exists = False
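
# Example invocation (the script name, checkpoint path, and output filename below are
# illustrative placeholders):
#   python export_onnx_model.py --checkpoint sam_checkpoint.pth --model-type vit_b \
#       --output sam_decoder.onnx --return-single-mask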

parser = argparse.ArgumentParser(
    description="Export the SAM prompt encoder and mask decoder to an ONNX model."
)

parser.add_argument(
    "--checkpoint", type=str, required=True, help="The path to the SAM model checkpoint."
)

parser.add_argument(
    "--output", type=str, required=True, help="The filename to save the ONNX model to."
)

parser.add_argument(
    "--model-type",
    type=str,
    default="default",
    help="In ['default', 'vit_b', 'vit_l']. Which type of SAM model to export.",
)

parser.add_argument(
    "--return-single-mask",
    action="store_true",
    help=(
        "If true, the exported ONNX model will only return the best mask, "
        "instead of returning multiple masks. For high resolution images "
        "this can improve runtime when upscaling masks is expensive."
    ),
)

parser.add_argument(
    "--opset",
    type=int,
    default=17,
    help="The ONNX opset version to use. Must be >=11",
)

parser.add_argument(
    "--quantize-out",
    type=str,
    default=None,
    help=(
        "If set, will quantize the model and save it with this name. "
        "Quantization is performed with quantize_dynamic from onnxruntime.quantization.quantize."
    ),
)

parser.add_argument(
    "--gelu-approximate",
    action="store_true",
    help=(
        "Replace GELU operations with approximations using tanh. Useful "
        "for some runtimes that have slow or unimplemented erf ops, used in GELU."
    ),
)

parser.add_argument(
    "--use-stability-score",
    action="store_true",
    help=(
        "Replaces the model's predicted mask quality score with the stability "
        "score calculated on the low resolution masks using an offset of 1.0. "
    ),
)

parser.add_argument(
    "--return-extra-metrics",
    action="store_true",
    help=(
        "The model will return five results: (masks, scores, stability_scores, "
        "areas, low_res_logits) instead of the usual three. This can be "
        "significantly slower for high resolution outputs."
    ),
)


def run_export(
    model_type: str,
    checkpoint: str,
    output: str,
    opset: int,
    return_single_mask: bool,
    gelu_approximate: bool = False,
    use_stability_score: bool = False,
    return_extra_metrics=False,
):
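    """Build the requested SAM model, wrap it in SamOnnxModel, and export it to `output`."""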
| print("Loading model...") | |
    if model_type == "vit_b":
        sam = build_sam_vit_b(checkpoint)
    elif model_type == "vit_l":
        sam = build_sam_vit_l(checkpoint)
    else:
        sam = build_sam(checkpoint)

    onnx_model = SamOnnxModel(
        model=sam,
        return_single_mask=return_single_mask,
        use_stability_score=use_stability_score,
        return_extra_metrics=return_extra_metrics,
    )
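
    # Optionally replace exact (erf-based) GELU with the tanh approximation, which some
    # runtimes execute faster or support more reliably.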
    if gelu_approximate:
        for n, m in onnx_model.named_modules():
            if isinstance(m, torch.nn.GELU):
                m.approximate = "tanh"
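
    # The number of prompt points varies per query, so expose it as a dynamic axis.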
    dynamic_axes = {
        "point_coords": {1: "num_points"},
        "point_labels": {1: "num_points"},
    }
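
    # Dummy inputs are only used to trace the model; the mask input is 4x the image
    # embedding resolution, matching the decoder's low-res mask size.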
    embed_dim = sam.prompt_encoder.embed_dim
    embed_size = sam.prompt_encoder.image_embedding_size
    mask_input_size = [4 * x for x in embed_size]
    dummy_inputs = {
        "image_embeddings": torch.randn(1, embed_dim, *embed_size, dtype=torch.float),
        "point_coords": torch.randint(low=0, high=1024, size=(1, 5, 2), dtype=torch.float),
        "point_labels": torch.randint(low=0, high=4, size=(1, 5), dtype=torch.float),
        "mask_input": torch.randn(1, 1, *mask_input_size, dtype=torch.float),
        "has_mask_input": torch.tensor([1], dtype=torch.float),
        "orig_im_size": torch.tensor([1500, 2250], dtype=torch.float),
    }
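
    # Run the wrapped model once eagerly to confirm the dummy inputs are valid before tracing.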
    _ = onnx_model(**dummy_inputs)

    output_names = ["masks", "iou_predictions", "low_res_masks"]
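
    # Tracing emits benign TracerWarnings and UserWarnings; suppress them while exporting.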
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        with open(output, "wb") as f:
            print(f"Exporting onnx model to {output}...")
            torch.onnx.export(
                onnx_model,
                tuple(dummy_inputs.values()),
                f,
                export_params=True,
                verbose=False,
                opset_version=opset,
                do_constant_folding=True,
                input_names=list(dummy_inputs.keys()),
                output_names=output_names,
                dynamic_axes=dynamic_axes,
            )
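
    # If onnxruntime is available, load the exported model and run it once on the same
    # dummy inputs as a smoke test.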
    if onnxruntime_exists:
        ort_inputs = {k: to_numpy(v) for k, v in dummy_inputs.items()}
        ort_session = onnxruntime.InferenceSession(output)
        _ = ort_session.run(None, ort_inputs)
        print("Model has successfully been run with ONNXRuntime.")


def to_numpy(tensor):
    return tensor.cpu().numpy()


if __name__ == "__main__":
    args = parser.parse_args()
    run_export(
        model_type=args.model_type,
        checkpoint=args.checkpoint,
        output=args.output,
        opset=args.opset,
        return_single_mask=args.return_single_mask,
        gelu_approximate=args.gelu_approximate,
        use_stability_score=args.use_stability_score,
        return_extra_metrics=args.return_extra_metrics,
    )
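
    # Optional post-export step: dynamic weight quantization to uint8 via onnxruntime.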
    if args.quantize_out is not None:
        assert onnxruntime_exists, "onnxruntime is required to quantize the model."
        from onnxruntime.quantization import QuantType  # type: ignore
        from onnxruntime.quantization.quantize import quantize_dynamic  # type: ignore

        print(f"Quantizing model and writing to {args.quantize_out}...")
        quantize_dynamic(
            model_input=args.output,
            model_output=args.quantize_out,
            optimize_model=True,
            per_channel=False,
            reduce_range=False,
            weight_type=QuantType.QUInt8,
        )

    print("Done!")