ONNX๋ก ๋ด๋ณด๋ด๊ธฐ[[export-to-onnx]]
ํ๋ก๋์ ํ๊ฒฝ์ ๐ค Transformers ๋ชจ๋ธ์ ๋ฐฐํฌํ ๋์๋ ํน์ ๋ฐํ์ ๋ฐ ํ๋์จ์ด ์์ ์ฌ๋ฆฌ๊ณ ์คํํ ์ ์๋๋ก ์ง๋ ฌํ๋ ํ์์ผ๋ก ๋ด๋ณด๋ด๊ธฐ๋ฅผ ๊ถ์ฅํฉ๋๋ค. ์ด ๊ฐ์ด๋์์๋ ๐ค Transformers ๋ชจ๋ธ์ ONNX (Open Neural Network eXchange)๋ก ๋ด๋ณด๋ด๋ ๋ฐฉ๋ฒ์ ์๋ดํฉ๋๋ค.
ONNX๋ ๋ฅ๋ฌ๋ ๋ชจ๋ธ์ ํํํ๊ธฐ ์ํ ๊ณตํต ํ์ผ ํ์๊ณผ ์ฐ์ฐ์๋ค์ ์ ์ํ๋ ๊ฐ๋ฐฉํ ํ์ค์ผ๋ก์จ PyTorch, TensorFlow ๋ฑ ๋ค์ํ ํ๋ ์์ํฌ์์ ์ง์๋ฉ๋๋ค. ๋ชจ๋ธ์ ONNX ํ์์ผ๋ก ๋ด๋ณด๋ด๋ฉด, (๋ณดํต _์ค๊ฐ ํํ (Intermediate Representation; IR)_์ด๋ผ๊ณ ๋ถ๋ฆฌ๋) ๊ณ์ฐ ๊ทธ๋ํ๊ฐ ๊ตฌ์ฑ๋ฉ๋๋ค. ๊ณ์ฐ ๊ทธ๋ํ๋ ์ ๊ฒฝ๋ง์ ํตํด ๋ฐ์ดํฐ๊ฐ ํ๋ฅด๋ ๋ฐฉ์, ์ฆ ์ด๋ค ์ฐ์ฐ์ด ์ด๋ ๋ถ๋ถ์ ์ฌ์ฉ๋์๋์ง๋ฅผ ๋ํ๋ ๋๋ค.
ํ์ค ์ฐ์ฐ ๋ฐ ๋ฐ์ดํฐ ํ์์ ์ฌ์ฉํ์ฌ ๊ทธ๋ํ๋ฅผ ๋ ธ์ถํ๊ธฐ ๋๋ฌธ์ ONNX๋ฅผ ์ฌ์ฉํ๋ฉด ํ๋ ์์ํฌ ๊ฐ ์ ํ์ด ์ฌ์์ง๋๋ค. ์๋ฅผ ๋ค์ด, PyTorch์์ ํ๋ จ๋ ๋ชจ๋ธ์ ONNX ํ์์ผ๋ก ๋ด๋ณด๋ธ ๋ค, TensorFlow์์ ๊ฐ์ ธ์ฌ ์ ์์ต๋๋ค. ๋ฌผ๋ก ๊ทธ ๋ฐ๋๋ ๊ฐ๋ฅํฉ๋๋ค.
๐ค Transformers๋ ๋ชจ๋ธ ์ฒดํฌํฌ์ธํธ๋ฅผ ONNX ๊ทธ๋ํ๋ก ๋ณํํ ์ ์๊ฒ ํด์ฃผ๋ transformers.onnx
ํจํค์ง๋ฅผ ์ ๊ณตํฉ๋๋ค. ์ด๊ฑธ ๊ฐ๋ฅ์ผ ํ๋ ๊ตฌ์ฑ ๊ฐ์ฒด๋ ์ฌ๋ฌ ๋ชจ๋ธ ์ํคํ
์ฒ๋ฅผ ๋์์ผ๋ก ๋ฏธ๋ฆฌ ์ ์๋์ด ์์ผ๋ฉฐ, ๋ค๋ฅธ ์ํคํ
์ฒ๋ก๋ ์ฝ๊ฒ ํ์ฅํ ์ ์๋๋ก ์ค๊ณ๋์์ต๋๋ค.
๐ค Optimum์์ optimum.exporters.onnx
ํจํค์ง๋ฅผ ์ฌ์ฉํ์ฌ ๐ค Transformers ๋ชจ๋ธ์ ๋ด๋ณด๋ผ ์๋ ์์ต๋๋ค.
๋ชจ๋ธ์ ๋ด๋ณด๋ธ ํ ๋ค์๊ณผ ๊ฐ์ด ์ฌ์ฉ๋ ์ ์์ต๋๋ค:
- ์์ํ ๋ฐ ๊ทธ๋ํ ์ต์ ํ์ ๊ฐ์ ๊ธฐ์ ์ ํตํด ์ถ๋ก ์ ์ต์ ํํฉ๋๋ค.
ORTModelForXXX
ํด๋์ค๋ฅผ ํตํด ONNX ๋ฐํ์์์ ์คํํฉ๋๋ค. ์ด ํด๋์ค๋ค์ ๐ค Transformers์์ ์ฌ์ฉํ๋AutoModel
API์ ๋์ผํฉ๋๋ค.- ์ต์ ํ๋ ์ถ๋ก ํ์ดํ๋ผ์ธ ์์ ์คํํฉ๋๋ค. ์ด ํ์ดํ๋ผ์ธ์ ๐ค Transformers์ [
pipeline
] ํจ์์ ๋์ผํ API๋ฅผ ๊ฐ์ต๋๋ค.
์ด๋ฌํ ๊ธฐ๋ฅ์ ๋ชจ๋ ์ดํด๋ณด๋ ค๋ฉด ๐ค Optimum ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ํ์ธํ์ธ์.
๋ฏธ๋ฆฌ ์ ์๋ ๊ตฌ์ฑ์๋ ๋ค์ ์ํคํ ์ฒ๊ฐ ํฌํจ๋ฉ๋๋ค:
- ALBERT
- BART
- BEiT
- BERT
- BigBird
- BigBird-Pegasus
- Blenderbot
- BlenderbotSmall
- BLOOM
- CamemBERT
- Chinese-CLIP
- CLIP
- CodeGen
- Conditional DETR
- ConvBERT
- ConvNeXT
- Data2VecText
- Data2VecVision
- DeBERTa
- DeBERTa-v2
- DeiT
- DETR
- DistilBERT
- EfficientNet
- ELECTRA
- ERNIE
- FlauBERT
- GPT Neo
- GPT-J
- GPT-Sw3
- GroupViT
- I-BERT
- ImageGPT
- LayoutLM
- LayoutLMv3
- LeViT
- Longformer
- LongT5
- M2M100
- Marian
- mBART
- MEGA
- MobileBERT
- MobileNetV1
- MobileNetV2
- MobileViT
- MT5
- OpenAI GPT-2
- OWL-ViT
- Perceiver
- PLBart
- PoolFormer
- RemBERT
- ResNet
- RoBERTa
- RoBERTa-PreLayerNorm
- RoFormer
- SegFormer
- SqueezeBERT
- Swin Transformer
- T5
- Table Transformer
- Vision Encoder decoder
- ViT
- Whisper
- X-MOD
- XLM
- XLM-RoBERTa
- XLM-RoBERTa-XL
- YOLOS
์์ผ๋ก์ ๋ ์น์ ์์๋ ์๋ ๋ด์ฉ์ ์ดํด๋ณด๊ฒ ์ต๋๋ค:
transformers.onnx
ํจํค์ง๋ฅผ ์ฌ์ฉํ์ฌ ์ง์๋๋ ๋ชจ๋ธ ๋ด๋ณด๋ด๊ธฐ- ์ง์๋์ง ์๋ ์ํคํ ์ฒ๋ฅผ ์ํด ์ฌ์ฉ์ ์ ์ ๋ชจ๋ธ ๋ด๋ณด๋ด๊ธฐ
๋ชจ๋ธ์ ONNX๋ก ๋ด๋ณด๋ด๊ธฐ[[exporting-a-model-to-onnx]]
์ด์ ๋ชจ๋ธ์ ๋ด๋ณด๋ผ ๋ optimum.exporters.onnx
๋ฅผ ์ฌ์ฉํ๋๋ก ๊ถ์ฅํฉ๋๋ค. transformers.onnx
์ ๋งค์ฐ ์ ์ฌํ๋ ๊ฑฑ์ ํ์ง ๋ง์ธ์!
๐ค Transformers ๋ชจ๋ธ์ ONNX๋ก ๋ด๋ณด๋ด๋ ค๋ฉด ๋จผ์ ๋ช ๊ฐ์ง ์ถ๊ฐ ์ข ์์ฑ์ ์ค์นํด์ผํฉ๋๋ค:
pip install transformers[onnx]
transformers.onnx
ํจํค์ง๋ ๋ค์๊ณผ ๊ฐ์ด Python ๋ชจ๋๋ก ์ฌ์ฉํ ์ ์์ต๋๋ค:
python -m transformers.onnx --help
usage: Hugging Face Transformers ONNX exporter [-h] -m MODEL [--feature {causal-lm, ...}] [--opset OPSET] [--atol ATOL] output
positional arguments:
output Path indicating where to store generated ONNX model.
optional arguments:
-h, --help show this help message and exit
-m MODEL, --model MODEL
Model ID on huggingface.co or path on disk to load model from.
--feature {causal-lm, ...}
The type of features to export the model with.
--opset OPSET ONNX opset version to export the model with.
--atol ATOL Absolute difference tolerance when validating the model.
๋ค์๊ณผ ๊ฐ์ด ๋ฏธ๋ฆฌ ์ ์๋ ๊ตฌ์ฑ์ ์ฌ์ฉํ์ฌ ์ฒดํฌํฌ์ธํธ๋ฅผ ๋ด๋ณด๋ผ ์ ์์ต๋๋ค:
python -m transformers.onnx --model=distilbert-base-uncased onnx/
๋ค์๊ณผ ๊ฐ์ ๋ก๊ทธ๊ฐ ํ์๋์ด์ผํฉ๋๋ค:
Validating ONNX model...
-[โ] ONNX model output names match reference model ({'last_hidden_state'})
- Validating ONNX Model output "last_hidden_state":
-[โ] (2, 8, 768) matches (2, 8, 768)
-[โ] all values close (atol: 1e-05)
All good, model saved at: onnx/model.onnx
์ด๋ ๊ฒ --model
์ธ์๋ก ์ ์๋ ์ฒดํฌํฌ์ธํธ์ ONNX ๊ทธ๋ํ๋ฅผ ๋ด๋ณด๋
๋๋ค. ์์์์๋ distilbert-base-uncased
์ด์ง๋ง, Hugging Face Hub์์ ๊ฐ์ ธ์๊ฑฐ๋ ๋ก์ปฌ์ ์ ์ฅ๋ ์ฒดํฌํฌ์ธํธ๋ค ๋ชจ๋ ๊ฐ๋ฅํฉ๋๋ค.
๊ฒฐ๊ณผ๋ก ๋์จ model.onnx
ํ์ผ์ ONNX ํ์ค์ ์ง์ํ๋ ๋ค์ํ ๊ฐ์๊ธฐ ์ค ํ๋์์ ์คํํ ์ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด, ๋ค์๊ณผ ๊ฐ์ด ONNX Runtime์์ ๋ชจ๋ธ์ ๊ฐ์ ธ์ค๊ณ ์คํํ ์ ์์ต๋๋ค:
>>> from transformers import AutoTokenizer
>>> from onnxruntime import InferenceSession
>>> tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
>>> session = InferenceSession("onnx/model.onnx")
>>> # ONNX Runtime expects NumPy arrays as input
>>> inputs = tokenizer("Using DistilBERT with ONNX Runtime!", return_tensors="np")
>>> outputs = session.run(output_names=["last_hidden_state"], input_feed=dict(inputs))
["last_hidden_state"]
์ ๊ฐ์ ํ์ํ ์ถ๋ ฅ ์ด๋ฆ์ ๊ฐ ๋ชจ๋ธ์ ONNX ๊ตฌ์ฑ์ ์ดํด๋ณด๋ฉด ์ป์ ์ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด, DistilBERT์ ๊ฒฝ์ฐ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
>>> from transformers.models.distilbert import DistilBertConfig, DistilBertOnnxConfig
>>> config = DistilBertConfig()
>>> onnx_config = DistilBertOnnxConfig(config)
>>> print(list(onnx_config.outputs.keys()))
["last_hidden_state"]
Hub์ TensorFlow ์ฒดํฌํฌ์ธํธ์ ๊ฒฝ์ฐ์๋ ๊ณผ์ ์ ๋์ผํฉ๋๋ค. ์๋ฅผ ๋ค์ด, ๋ค์๊ณผ ๊ฐ์ด Keras organization์์ TensorFlow ์ฒดํฌํฌ์ธํธ๋ฅผ ๋ด๋ณด๋ผ ์ ์์ต๋๋ค:
python -m transformers.onnx --model=keras-io/transformers-qa onnx/
๋ก์ปฌ์ ์ ์ฅ๋ ๋ชจ๋ธ์ ๋ด๋ณด๋ด๋ ค๋ฉด ๋ชจ๋ธ์ ๊ฐ์ค์น ๋ฐ ํ ํฌ๋์ด์ ํ์ผ์ด ์ ์ฅ๋ ๋๋ ํ ๋ฆฌ๊ฐ ํ์ํฉ๋๋ค. ์๋ฅผ ๋ค์ด, ๋ค์๊ณผ ๊ฐ์ด ์ฒดํฌํฌ์ธํธ๋ฅผ ๊ฐ์ ธ์ค๊ณ ์ ์ฅํ ์ ์์ต๋๋ค:
>>> from transformers import AutoTokenizer, AutoModelForSequenceClassification
>>> # Load tokenizer and PyTorch weights form the Hub
>>> tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
>>> pt_model = AutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
>>> # Save to disk
>>> tokenizer.save_pretrained("local-pt-checkpoint")
>>> pt_model.save_pretrained("local-pt-checkpoint")
์ฒดํฌํฌ์ธํธ๋ฅผ ์ ์ฅํ ํ, transformers.onnx
ํจํค์ง์ --model
์ธ์๋ฅผ ์ํ๋ ๋๋ ํ ๋ฆฌ๋ก ์ง์ ํ์ฌ ONNX๋ก ๋ด๋ณด๋ผ ์ ์์ต๋๋ค:
python -m transformers.onnx --model=local-pt-checkpoint onnx/
>>> from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
>>> # Load tokenizer and TensorFlow weights from the Hub
>>> tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")
>>> tf_model = TFAutoModelForSequenceClassification.from_pretrained("distilbert-base-uncased")
>>> # Save to disk
>>> tokenizer.save_pretrained("local-tf-checkpoint")
>>> tf_model.save_pretrained("local-tf-checkpoint")
์ฒดํฌํฌ์ธํธ๋ฅผ ์ ์ฅํ ํ, transformers.onnx
ํจํค์ง์ --model
์ธ์๋ฅผ ์ํ๋ ๋๋ ํ ๋ฆฌ๋ก ์ง์ ํ์ฌ ONNX๋ก ๋ด๋ณด๋ผ ์ ์์ต๋๋ค:
python -m transformers.onnx --model=local-tf-checkpoint onnx/
๋ค๋ฅธ ๋ชจ๋ธ ์์ ์ ๋ํ ๊ธฐ๋ฅ ์ ํ[[selecting-features-for-different-model-tasks]]
์ด์ ๋ชจ๋ธ์ ๋ด๋ณด๋ผ ๋ optimum.exporters.onnx
๋ฅผ ์ฌ์ฉํ๋๋ก ๊ถ์ฅํฉ๋๋ค. ์์
์ ์ ํํ๋ ๋ฐฉ๋ฒ์ ์์๋ณด๋ ค๋ฉด ๐ค Optimum ๋ฌธ์๋ฅผ ํ์ธํ์ธ์.
๋ค๋ฅธ ์ ํ์ ํ์คํฌ์ ๋ง์ถฐ์ ๋ชจ๋ธ์ ๋ด๋ณด๋ผ ์ ์๋๋ก ๋ฏธ๋ฆฌ ์ ์๋ ๊ตฌ์ฑ๋ง๋ค ์ผ๋ จ์ _๊ธฐ๋ฅ_์ด ํฌํจ๋์ด ์์ต๋๋ค. ์๋ ํ์ ๋์ ์๋๋๋ก ๊ฐ ๊ธฐ๋ฅ์ ๋ค๋ฅธ AutoClass
์ ์ฐ๊ด๋์ด ์์ต๋๋ค.
Feature | Auto Class |
---|---|
causal-lm , causal-lm-with-past |
AutoModelForCausalLM |
default , default-with-past |
AutoModel |
masked-lm |
AutoModelForMaskedLM |
question-answering |
AutoModelForQuestionAnswering |
seq2seq-lm , seq2seq-lm-with-past |
AutoModelForSeq2SeqLM |
sequence-classification |
AutoModelForSequenceClassification |
token-classification |
AutoModelForTokenClassification |
๊ฐ ๊ตฌ์ฑ์์ [~transformers.onnx.FeaturesManager
]๋ฅผ ํตํด ์ง์๋๋ ๊ธฐ๋ฅ ๋ชฉ๋ก์ ์ฐพ์ ์ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด, DistilBERT์ ๊ฒฝ์ฐ ๋ค์๊ณผ ๊ฐ์ต๋๋ค:
>>> from transformers.onnx.features import FeaturesManager
>>> distilbert_features = list(FeaturesManager.get_supported_features_for_model_type("distilbert").keys())
>>> print(distilbert_features)
["default", "masked-lm", "causal-lm", "sequence-classification", "token-classification", "question-answering"]
๊ทธ๋ฐ ๋ค์ transformers.onnx
ํจํค์ง์ --feature
์ธ์์ ์ด๋ฌํ ๊ธฐ๋ฅ ์ค ํ๋๋ฅผ ์ ๋ฌํ ์ ์์ต๋๋ค. ์๋ฅผ ๋ค์ด, ํ
์คํธ ๋ถ๋ฅ ๋ชจ๋ธ์ ๋ด๋ณด๋ด๋ ค๋ฉด ๋ค์๊ณผ ๊ฐ์ด Hub์์ ๋ฏธ์ธ ์กฐ์ ๋ ๋ชจ๋ธ์ ์ ํํ๊ณ ์คํํ ์ ์์ต๋๋ค:
python -m transformers.onnx --model=distilbert-base-uncased-finetuned-sst-2-english \
--feature=sequence-classification onnx/
๋ค์๊ณผ ๊ฐ์ ๋ก๊ทธ๊ฐ ํ์๋ฉ๋๋ค:
Validating ONNX model...
-[โ] ONNX model output names match reference model ({'logits'})
- Validating ONNX Model output "logits":
-[โ] (2, 2) matches (2, 2)
-[โ] all values close (atol: 1e-05)
All good, model saved at: onnx/model.onnx
์ด๋ ๋ฏธ์ธ ์กฐ์ ๋ ๋ชจ๋ธ์ ์ถ๋ ฅ๋ช
์ ์ด์ ์ distilbert-base-uncased
์ฒดํฌํฌ์ธํธ์์ ๋ดค๋ last_hidden_state
์ ๋ฌ๋ฆฌ logits
์
๋๋ค. ์ํ์ค ๋ถ๋ฅ๋ฅผ ์ํด ๋ฏธ์ธ ์กฐ์ ๋์๊ธฐ ๋๋ฌธ์ ์์๋๋ก ์
๋๋ค.
with-past
์ ๋ฏธ์ฌ๋ฅผ ๊ฐ์ง ๊ธฐ๋ฅ(์: causal-lm-with-past
)์ ๋ฏธ๋ฆฌ ๊ณ์ฐ๋ ์จ๊ฒจ์ง ์ํ(hidden states; ์ดํ
์
๋ธ๋ก ์ ํค-๊ฐ ์)๋ฅผ ์ฌ์ฉํ์ฌ ๋น ๋ฅธ ์๊ธฐ ํ๊ท ๋์ฝ๋ฉ์ด ๊ฐ๋ฅํ ๋ชจ๋ธ ํด๋์ค๋ค์
๋๋ค.
VisionEncoderDecoder
์ ํ ๋ชจ๋ธ์ ๊ฒฝ์ฐ, ์ธ์ฝ๋ ๋ฐ ๋์ฝ๋ ๋ถ๋ถ์ ๊ฐ๊ฐ encoder_model.onnx
๋ฐ decoder_model.onnx
๋ผ๋ ๋ ๊ฐ์ ONNX ํ์ผ๋ก ๋ถ๋ฆฌํ์ฌ ๋ด๋ณด๋
๋๋ค.
์ง์๋์ง ์๋ ์ํคํ ์ฒ๋ฅผ ์ํ ๋ชจ๋ธ ๋ด๋ณด๋ด๊ธฐ[[exporting-a-model-for-an-unsupported-architecture]]
ํ์ฌ ๋ด๋ณด๋ผ ์ ์๋ ๋ชจ๋ธ์ ์ง์ํ๋๋ก ๊ธฐ์ฌํ๋ ค๋ฉด ๋จผ์ optimum.exporters.onnx
์์ ์ง์๋๋์ง ํ์ธํ๊ณ ์ง์๋์ง ์๋ ๊ฒฝ์ฐ ๐ค Optimum์ ๊ธฐ์ฌํ์ธ์.
๋ผ์ด๋ธ๋ฌ๋ฆฌ์์ ์ง์ ์ง์ํ์ง ์๋ ์ํคํ ์ฒ์ ๋ชจ๋ธ์ ๋ด๋ณด๋ด๋ ค๋ฉด ์ธ ๊ฐ์ง ์ฃผ์ ๋จ๊ณ๋ฅผ ๊ฑฐ์ณ์ผ ํฉ๋๋ค:
- ์ฌ์ฉ์ ์ ์ ONNX ๊ตฌ์ฑ์ ๊ตฌํํ๊ธฐ
- ๋ชจ๋ธ์ ONNX๋ก ๋ด๋ณด๋ด๊ธฐ
- PyTorch ๋ฐ ๋ด๋ณด๋ธ ๋ชจ๋ธ์ ์ถ๋ ฅ ๊ฒ์ฆํ๊ธฐ
์ด ์น์ ์์๋ DistilBERT๊ฐ ์ด๋ป๊ฒ ๊ตฌํ๋์๋์ง ๊ฐ ๋จ๊ณ๋ง๋ค ์์ธํ ์ดํด๋ณด๊ฒ ์ต๋๋ค.
์ฌ์ฉ์ ์ ์ ONNX ๊ตฌ์ฑ์ ๊ตฌํํ๊ธฐ[[implementing-a-custom-onnx-configuration]]
ONNX ๊ตฌ์ฑ ๊ฐ์ฒด๋ถํฐ ์์ํด ๋ด ์๋ค. ๋ด๋ณด๋ด๋ ค๋ ๋ชจ๋ธ ์ํคํ ์ฒ ์ ํ์ ๋ฐ๋ผ ์์ํด์ผํ๋ ์ธ ๊ฐ์ง ์ถ์ ํด๋์ค๋ฅผ ์ ๊ณตํฉ๋๋ค:
- ์ธ์ฝ๋ ๊ธฐ๋ฐ ๋ชจ๋ธ์ [
~onnx.config.OnnxConfig
]๋ฅผ ์์ํฉ๋๋ค. - ๋์ฝ๋ ๊ธฐ๋ฐ ๋ชจ๋ธ์ [
~onnx.config.OnnxConfigWithPast
]๋ฅผ ์์ํฉ๋๋ค. - ์ธ์ฝ๋-๋์ฝ๋ ๋ชจ๋ธ์ [
~onnx.config.OnnxSeq2SeqConfigWithPast
]๋ฅผ ์์ํฉ๋๋ค.
์ฌ์ฉ์ ์ ์ ONNX ๊ตฌ์ฑ์ ๊ตฌํํ๋ ์ข์ ๋ฐฉ๋ฒ์ ๋น์ทํ ์ํคํ
์ฒ์ configuration_<model_name>.py
ํ์ผ์์ ๊ธฐ์กด ๊ตฌํ์ ํ์ธํ๋ ๊ฒ์
๋๋ค.
DistilBERT๋ ์ธ์ฝ๋ ๊ธฐ๋ฐ ๋ชจ๋ธ์ด๋ฏ๋ก ํด๋น ๊ตฌ์ฑ์ OnnxConfig
๋ฅผ ์์ํฉ๋๋ค.
>>> from typing import Mapping, OrderedDict
>>> from transformers.onnx import OnnxConfig
>>> class DistilBertOnnxConfig(OnnxConfig):
... @property
... def inputs(self) -> Mapping[str, Mapping[int, str]]:
... return OrderedDict(
... [
... ("input_ids", {0: "batch", 1: "sequence"}),
... ("attention_mask", {0: "batch", 1: "sequence"}),
... ]
... )
๊ฐ ๊ตฌ์ฑ ๊ฐ์ฒด๋ inputs
์์ฑ์ ๊ตฌํํ๊ณ ๋งคํ์ ๋ฐํํด์ผ ํฉ๋๋ค. ๋งคํ์ ํค๋ ์์ ์
๋ ฅ์ ํด๋นํ๊ณ ๊ฐ์ ํด๋น ์
๋ ฅ์ ์ถ์ ๋ํ๋
๋๋ค. DistilBERT์ ๊ฒฝ์ฐ input_ids
๋ฐ attention_mask
๋ ๊ฐ์ ์
๋ ฅ์ด ํ์ํ๋ฐ์. ๋ ์
๋ ฅ ๋ชจ๋ (batch_size, sequence_length)
์ ๋์ผํ ์ฐจ์์ด๊ธฐ ๋๋ฌธ์ ๊ตฌ์ฑ์์๋ ๋๊ฐ์ ์ถ์ ์ฌ์ฉํฉ๋๋ค.
DistilBertOnnxConfig
์ inputs
์์ฑ์ด OrderedDict
๋ผ๋ ๊ฒ์ ์ ์ํ์ธ์. ์ด๋ ๊ฒ ํ๋ฉด ์
๋ ฅ์ด ๊ทธ๋ํ๋ฅผ ๋ฐ๋ผ ํ๋ฅผ ๋ PreTrainedModel.forward()
๋ฉ์๋ ์ ์๋ง์ ์๋์ ์ธ ์์น์ ์๋๋ก ๋ณด์ฅํฉ๋๋ค. ์ฌ์ฉ์ ์ ์ ONNX ๊ตฌ์ฑ์ ๊ตฌํํ ๋๋ inputs
๋ฐ outputs
์์ฑ์ผ๋ก OrderedDict
๋ฅผ ์ฌ์ฉํ๋ ๊ฒ์ ๊ถ์ฅํฉ๋๋ค.
ONNX ๊ตฌ์ฑ์ ๊ตฌํํ ํ์๋ ๋ค์๊ณผ ๊ฐ์ด ๊ธฐ๋ณธ ๋ชจ๋ธ์ ๊ตฌ์ฑ์ ์ ๊ณตํ์ฌ ์ธ์คํด์คํ ํ ์ ์์ต๋๋ค:
>>> from transformers import AutoConfig
>>> config = AutoConfig.from_pretrained("distilbert-base-uncased")
>>> onnx_config = DistilBertOnnxConfig(config)
๊ฒฐ๊ณผ ๊ฐ์ฒด์๋ ์ฌ๋ฌ ๊ฐ์ง ์ ์ฉํ ์์ฑ์ด ์์ต๋๋ค. ์๋ฅผ ๋ค์ด ONNX๋ก ๋ด๋ณด๋ผ ๋ ์ฐ์ผ ONNX ์ฐ์ฐ์ ์งํฉ์ ๋ณผ ์ ์์ต๋๋ค:
>>> print(onnx_config.default_onnx_opset)
11
๋ค์๊ณผ ๊ฐ์ด ๋ชจ๋ธ์ ์ฐ๊ฒฐ๋ ์ถ๋ ฅ์ ๋ณผ ์๋ ์์ต๋๋ค:
>>> print(onnx_config.outputs)
OrderedDict([("last_hidden_state", {0: "batch", 1: "sequence"})])
์ถ๋ ฅ ์์ฑ์ด ์
๋ ฅ๊ณผ ๋์ผํ ๊ตฌ์กฐ์์ ์ ์ํ์ธ์. ๊ฐ ์ถ๋ ฅ์ ์ด๋ฆ๊ณผ ์ฐจ์์ด OrderedDict
์ ํค-๊ฐ์ผ๋ก ์ ์ฅ๋์ด ์์ต๋๋ค. ์ถ๋ ฅ ๊ตฌ์กฐ๋ ๊ตฌ์ฑ์ ์ด๊ธฐํํ ๋ ์ ํํ ๊ธฐ๋ฅ๊ณผ ๊ด๋ จ์ด ์์ต๋๋ค. ๊ธฐ๋ณธ์ ์ผ๋ก ONNX ๊ตฌ์ฑ์ AutoModel
ํด๋์ค๋ก ๊ฐ์ ธ์จ ๋ชจ๋ธ์ ๋ด๋ณด๋ผ ๋ ์ฐ์ด๋ default
๊ธฐ๋ฅ์ผ๋ก ์ด๊ธฐํ๋ฉ๋๋ค. ๋ค๋ฅธ ํ์คํฌ๋ฅผ ์ํด ๋ชจ๋ธ์ ๋ด๋ณด๋ด๋ ค๋ฉด ONNX ๊ตฌ์ฑ์ ์ด๊ธฐํํ ๋ task
์ธ์์ ๋ค๋ฅธ ๊ธฐ๋ฅ์ ๋ฃ์ผ๋ฉด ๋ฉ๋๋ค. ์๋ฅผ ๋ค์ด, ์ํ์ค ๋ถ๋ฅ ๋จ๊ณ๋ฅผ ๋ง๋ถ์ธ DistilBERT๋ฅผ ๋ด๋ณด๋ด๋ ค๋ฉด, ์ด๋ ๊ฒ ํด๋ณผ ์ ์์ต๋๋ค:
>>> from transformers import AutoConfig
>>> config = AutoConfig.from_pretrained("distilbert-base-uncased")
>>> onnx_config_for_seq_clf = DistilBertOnnxConfig(config, task="sequence-classification")
>>> print(onnx_config_for_seq_clf.outputs)
OrderedDict([('logits', {0: 'batch'})])
[~onnx.config.OnnxConfig
]๋ ๋ค๋ฅธ ๊ตฌ์ฑ ํด๋์ค์ ์ฐ๊ฒฐ๋ ๋ชจ๋ ๊ธฐ๋ณธ ์์ฑ ๋ฐ ๋ฉ์๋๋ ํ์์ ๋ฐ๋ผ ๋ชจ๋ ์ฌ์ ์ํ ์ ์์ต๋๋ค. ๊ณ ๊ธ ์์ ๋ก [BartOnnxConfig
]๋ฅผ ํ์ธํ์ธ์.
๋ชจ๋ธ ๋ด๋ณด๋ด๊ธฐ[[exporting-the-model]]
ONNX ๊ตฌ์ฑ์ ๊ตฌํํ๋ค๋ฉด, ๋ค์ ๋จ๊ณ๋ ๋ชจ๋ธ์ ๋ด๋ณด๋ด๋ ๊ฒ์
๋๋ค. ์ด์ transformers.onnx
ํจํค์ง์์ ์ ๊ณตํ๋ export()
ํจ์๋ฅผ ์ดํด๋ณด๊ฒ ์ต๋๋ค. ์ด ํจ์๋ ONNX ๊ตฌ์ฑ, ๊ธฐ๋ณธ ๋ชจ๋ธ, ํ ํฌ๋์ด์ , ๊ทธ๋ฆฌ๊ณ ๋ด๋ณด๋ผ ํ์ผ์ ๊ฒฝ๋ก๋ฅผ ์
๋ ฅ์ผ๋ก ๋ฐ์ต๋๋ค:
>>> from pathlib import Path
>>> from transformers.onnx import export
>>> from transformers import AutoTokenizer, AutoModel
>>> onnx_path = Path("model.onnx")
>>> model_ckpt = "distilbert-base-uncased"
>>> base_model = AutoModel.from_pretrained(model_ckpt)
>>> tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
>>> onnx_inputs, onnx_outputs = export(tokenizer, base_model, onnx_config, onnx_config.default_onnx_opset, onnx_path)
export()
ํจ์๊ฐ ๋ฐํํ๋ onnx_inputs
์ onnx_outputs
๋ ๊ตฌ์ฑ์ inputs
์ outputs
์์ฑ์์ ์ ์๋ ํค ๋ชฉ๋ก์
๋๋ค. ๋ชจ๋ธ์ ๋ด๋ณด๋ธ ํ ๋ค์๊ณผ ๊ฐ์ด ๋ชจ๋ธ์ด ์ ๊ตฌ์ฑ๋์ด ์๋์ง ํ
์คํธํ ์ ์์ต๋๋ค:
>>> import onnx
>>> onnx_model = onnx.load("model.onnx")
>>> onnx.checker.check_model(onnx_model)
๋ชจ๋ธ ํฌ๊ธฐ๊ฐ 2GB๋ณด๋ค ํฐ ๊ฒฝ์ฐ ๋ด๋ณด๋ด๋ ์ค์ ์ฌ๋ฌ ์ถ๊ฐ ํ์ผ๋ค์ด ์์ฑ๋๋ ๊ฒ์ ๋ณผ ์ ์์ต๋๋ค. ์ฌ์ค ONNX๋ ๋ชจ๋ธ์ ์ ์ฅํ๊ธฐ ์ํด Protocol Buffers๋ฅผ ์ฌ์ฉํ๋๋ฐ, ๋ฒํผ๋ 2GB์ ํฌ๊ธฐ ์ ํ์ด ์๊ธฐ ๋๋ฌธ์ ์์ฐ์ค๋ฌ์ด ์ผ์ ๋๋ค. ์ธ๋ถ ๋ฐ์ดํฐ๋ฅผ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ๊ฐ์ ธ์ค๋ ๋ฐฉ๋ฒ์ ONNX ๋ฌธ์๋ฅผ ์ฐธ์กฐํ์ธ์.
๋ชจ๋ธ์ ์ถ๋ ฅ ๊ฒ์ฆํ๊ธฐ[[validating-the-model-outputs]]
๋ง์ง๋ง ๋จ๊ณ๋ ๊ธฐ์กด ๋ชจ๋ธ๊ณผ ๋ด๋ณด๋ธ ๋ชจ๋ธ์ ์ถ๋ ฅ์ด ์ผ์ ํ ์ค์ฐจ ๋ฒ์ ๋ด์์ ๋์ผํ๋ค๋ ๊ฒ์ ๊ฒ์ฆํ๋ ๊ฒ์
๋๋ค. ๊ทธ๋ฌ๋ ค๋ฉด transformers.onnx
ํจํค์ง์์ ์ ๊ณตํ๋ validate_model_outputs()
ํจ์๋ฅผ ์ฌ์ฉํ ์ ์์ต๋๋ค:
>>> from transformers.onnx import validate_model_outputs
>>> validate_model_outputs(
... onnx_config, tokenizer, base_model, onnx_path, onnx_outputs, onnx_config.atol_for_validation
... )
์ด ํจ์๋ [~transformers.onnx.OnnxConfig.generate_dummy_inputs
] ๋ฉ์๋๋ก ๊ธฐ์กด ๋ฐ ๋ด๋ณด๋ธ ๋ชจ๋ธ์ ์
๋ ฅ์ ์์ฑํ๋ฉฐ, ๊ฒ์ฆ์ ์ฌ์ฉ๋ ์ค์ฐจ ๋ฒ์๋ ๊ตฌ์ฑ์์ ์ ์ํ ์ ์์ต๋๋ค. ์ผ๋ฐ์ ์ผ๋ก๋ 1e-6์์ 1e-4 ๋ฒ์ ๋ด์์ ํฉ์ํ์ง๋ง, 1e-3๋ณด๋ค ์๋ค๋ฉด ๋ฌธ์ ์์ ๊ฐ๋ฅ์ฑ์ด ๋์ต๋๋ค.
๐ค Transformers์ ์ ๊ตฌ์ฑ ์ถ๊ฐํ๊ธฐ[[contributing-a-new-configuration-to-transformers]]
๋ฏธ๋ฆฌ ์ ์๋ ๊ตฌ์ฑ์ ์ซ์๋ฅผ ๋๋ฆฌ๋ ค๊ณ ๋ ธ๋ ฅํ๊ณ ์์ผ๋ฉฐ, ์ปค๋ฎค๋ํฐ์ ๊ธฐ์ฌ๋ฅผ ํ์ํฉ๋๋ค! ๋ผ์ด๋ธ๋ฌ๋ฆฌ์ ๋น์ ๋ง์ ๊ตฌ์ฑ์ ์ถ๊ฐํ๋ ค๋ฉด ๋ค์ ๋จ๊ณ๋ฅผ ๊ธฐ์ตํด์ฃผ์ธ์:
configuration_<model_name>.py
ํ์ผ์ ONNX ๊ตฌ์ฑ์ ๊ตฌํํ์ธ์.- [
~onnx.features.FeatureManager
]์ ๋ชจ๋ธ ์ํคํ ์ฒ ๋ฐ ํด๋น ๊ธฐ๋ฅ์ ํฌํจํ์ธ์. test_onnx_v2.py
์ ํ ์คํธ์ ๋ชจ๋ธ ์ํคํ ์ฒ๋ฅผ ์ถ๊ฐํ์ธ์.
์์ง ๊ฐ์ด ์ ์กํ์ ๋ค๋ฉด, IBERT ๊ตฌ์ฑ์ด ์ด๋ป๊ฒ ๊ธฐ์ฌ๋์๋์ง ํ์ธํด๋ณด์ธ์.