TorchScript๋ก ๋ด๋ณด๋ด๊ธฐ[[export-to-torchscript]]
TorchScript๋ฅผ ํ์ฉํ ์คํ์ ์์ง ์ด๊ธฐ ๋จ๊ณ๋ก, ๊ฐ๋ณ์ ์ธ ์ ๋ ฅ ํฌ๊ธฐ ๋ชจ๋ธ๋ค์ ํตํด ๊ทธ ๊ธฐ๋ฅ์ฑ์ ๊ณ์ ํ๊ตฌํ๊ณ ์์ต๋๋ค. ์ด ๊ธฐ๋ฅ์ ์ ํฌ๊ฐ ๊ด์ฌ์ ๋๊ณ ์๋ ๋ถ์ผ ์ค ํ๋์ด๋ฉฐ, ์์ผ๋ก ์ถ์๋ ๋ฒ์ ์์ ๋ ๋ง์ ์ฝ๋ ์์ , ๋ ์ ์ฐํ ๊ตฌํ, ๊ทธ๋ฆฌ๊ณ Python ๊ธฐ๋ฐ ์ฝ๋์ ์ปดํ์ผ๋ TorchScript๋ฅผ ๋น๊ตํ๋ ๋ฒค์น๋งํฌ๋ฅผ ๋ฑ์ ํตํด ๋ถ์์ ์ฌํํ ์์ ์ ๋๋ค.
TorchScript ๋ฌธ์์์๋ ์ด๋ ๊ฒ ๋งํฉ๋๋ค.
TorchScript๋ PyTorch ์ฝ๋์์ ์ง๋ ฌํ ๋ฐ ์ต์ ํ ๊ฐ๋ฅํ ๋ชจ๋ธ์ ์์ฑํ๋ ๋ฐฉ๋ฒ์ ๋๋ค.
JIT๊ณผ TRACE๋ ๊ฐ๋ฐ์๊ฐ ๋ชจ๋ธ์ ๋ด๋ณด๋ด์ ํจ์จ ์งํฅ์ ์ธ C++ ํ๋ก๊ทธ๋จ๊ณผ ๊ฐ์ ๋ค๋ฅธ ํ๋ก๊ทธ๋จ์์ ์ฌ์ฌ์ฉํ ์ ์๋๋ก ํ๋ PyTorch ๋ชจ๋์ ๋๋ค.
PyTorch ๊ธฐ๋ฐ Python ํ๋ก๊ทธ๋จ๊ณผ ๋ค๋ฅธ ํ๊ฒฝ์์ ๋ชจ๋ธ์ ์ฌ์ฌ์ฉํ ์ ์๋๋ก, ๐ค Transformers ๋ชจ๋ธ์ TorchScript๋ก ๋ด๋ณด๋ผ ์ ์๋ ์ธํฐํ์ด์ค๋ฅผ ์ ๊ณตํฉ๋๋ค. ์ด ๋ฌธ์์์๋ TorchScript๋ฅผ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ๋ด๋ณด๋ด๊ณ ์ฌ์ฉํ๋ ๋ฐฉ๋ฒ์ ์ค๋ช ํฉ๋๋ค.
๋ชจ๋ธ์ ๋ด๋ณด๋ด๋ ค๋ฉด ๋ ๊ฐ์ง๊ฐ ํ์ํฉ๋๋ค:
torchscript
ํ๋๊ทธ๋ก ๋ชจ๋ธ ์ธ์คํด์คํ- ๋๋ฏธ ์ ๋ ฅ์ ์ฌ์ฉํ ์์ ํ(forward pass)
์ด ํ์ ์กฐ๊ฑด๋ค์ ์๋์ ์์ธํ ์ค๋ช ๋ ๊ฒ์ฒ๋ผ ๊ฐ๋ฐ์๋ค์ด ์ฃผ์ํด์ผ ํ ์ฌ๋ฌ ์ฌํญ๋ค์ ์๋ฏธํฉ๋๋ค.
TorchScript ํ๋๊ทธ์ ๋ฌถ์ธ ๊ฐ์ค์น(tied weights)[[torchscript-flag-and-tied-weights]]
torchscript
ํ๋๊ทธ๊ฐ ํ์ํ ์ด์ ๋ ๋๋ถ๋ถ์ ๐ค Transformers ์ธ์ด ๋ชจ๋ธ์์ Embedding
๋ ์ด์ด์ Decoding
๋ ์ด์ด ๊ฐ์ ๋ฌถ์ธ ๊ฐ์ค์น(tied weights)๊ฐ ์กด์ฌํ๊ธฐ ๋๋ฌธ์
๋๋ค.
TorchScript๋ ๋ฌถ์ธ ๊ฐ์ค์น๋ฅผ ๊ฐ์ง ๋ชจ๋ธ์ ๋ด๋ณด๋ผ ์ ์์ผ๋ฏ๋ก, ๋ฏธ๋ฆฌ ๊ฐ์ค์น๋ฅผ ํ๊ณ ๋ณต์ ํด์ผ ํฉ๋๋ค.
torchscript
ํ๋๊ทธ๋ก ์ธ์คํด์คํ๋ ๋ชจ๋ธ์ Embedding
๋ ์ด์ด์ Decoding
๋ ์ด์ด๊ฐ ๋ถ๋ฆฌ๋์ด ์์ผ๋ฏ๋ก ์ดํ์ ํ๋ จํด์๋ ์ ๋ฉ๋๋ค.
ํ๋ จ์ ํ๊ฒ ๋๋ฉด ๋ ๋ ์ด์ด ๊ฐ ๋๊ธฐํ๊ฐ ํด์ ๋์ด ์์์น ๋ชปํ ๊ฒฐ๊ณผ๊ฐ ๋ฐ์ํ ์ ์์ต๋๋ค.
์ธ์ด ๋ชจ๋ธ ํค๋๋ฅผ ๊ฐ์ง ์์ ๋ชจ๋ธ์ ๊ฐ์ค์น๊ฐ ๋ฌถ์ฌ ์์ง ์์์ ์ด ๋ฌธ์ ๊ฐ ๋ฐ์ํ์ง ์์ต๋๋ค.
์ด๋ฌํ ๋ชจ๋ธ๋ค์ torchscript
ํ๋๊ทธ ์์ด ์์ ํ๊ฒ ๋ด๋ณด๋ผ ์ ์์ต๋๋ค.
๋๋ฏธ ์ ๋ ฅ๊ณผ ํ์ค ๊ธธ์ด[[dummy-inputs-and-standard-lengths]]
๋๋ฏธ ์ ๋ ฅ(dummy inputs)์ ๋ชจ๋ธ์ ์์ ํ(forward pass)์ ์ฌ์ฉ๋ฉ๋๋ค. ์ ๋ ฅ ๊ฐ์ด ๋ ์ด์ด๋ฅผ ํตํด ์ ํ๋๋ ๋์, PyTorch๋ ๊ฐ ํ ์์์ ์คํ๋ ๋ค๋ฅธ ์ฐ์ฐ์ ์ถ์ ํฉ๋๋ค. ์ด๋ฌํ ๊ธฐ๋ก๋ ์ฐ์ฐ์ ๋ชจ๋ธ์ *์ถ์ (trace)*์ ์์ฑํ๋ ๋ฐ ์ฌ์ฉ๋ฉ๋๋ค.
์ถ์ ์ ์ ๋ ฅ์ ์ฐจ์์ ๊ธฐ์ค์ผ๋ก ์์ฑ๋ฉ๋๋ค. ๋ฐ๋ผ์ ๋๋ฏธ ์ ๋ ฅ์ ์ฐจ์์ ์ ํ๋์ด, ๋ค๋ฅธ ์ํ์ค ๊ธธ์ด๋ ๋ฐฐ์น ํฌ๊ธฐ์์๋ ์๋ํ์ง ์์ต๋๋ค. ๋ค๋ฅธ ํฌ๊ธฐ๋ก ์๋ํ ๊ฒฝ์ฐ ๋ค์๊ณผ ๊ฐ์ ์ค๋ฅ๊ฐ ๋ฐ์ํฉ๋๋ค:
`The expanded size of the tensor (3) must match the existing size (7) at non-singleton dimension 2`
์ถ๋ก ์ค ๋ชจ๋ธ์ ๊ณต๊ธ๋ ๊ฐ์ฅ ํฐ ์ ๋ ฅ๋งํผ ํฐ ๋๋ฏธ ์ ๋ ฅ ํฌ๊ธฐ๋ก ๋ชจ๋ธ์ ์ถ์ ํ๋ ๊ฒ์ด ์ข์ต๋๋ค. ํจ๋ฉ์ ๋๋ฝ๋ ๊ฐ์ ์ฑ์ฐ๋ ๋ฐ ๋์์ด ๋ ์ ์์ต๋๋ค. ๊ทธ๋ฌ๋ ๋ชจ๋ธ์ด ๋ ํฐ ์ ๋ ฅ ํฌ๊ธฐ๋ก ์ถ์ ๋๊ธฐ ๋๋ฌธ์, ํ๋ ฌ์ ์ฐจ์์ด ์ปค์ง๊ณ ๊ณ์ฐ๋์ด ๋ง์์ง๋๋ค.
๋ค์ํ ์ํ์ค ๊ธธ์ด ๋ชจ๋ธ์ ๋ด๋ณด๋ผ ๋๋ ๊ฐ ์ ๋ ฅ์ ๋ํด ์ํ๋๋ ์ด ์ฐ์ฐ ํ์์ ์ฃผ์ํ๊ณ ์ฑ๋ฅ์ ์ฃผ์ ๊น๊ฒ ํ์ธํ์ธ์.
Python์์ TorchScript ์ฌ์ฉํ๊ธฐ[[using-torchscript-in-python]]
์ด ์น์ ์์๋ ๋ชจ๋ธ์ ์ ์ฅํ๊ณ ๊ฐ์ ธ์ค๋ ๋ฐฉ๋ฒ, ์ถ์ ์ ์ฌ์ฉํ์ฌ ์ถ๋ก ํ๋ ๋ฐฉ๋ฒ์ ๋ณด์ฌ์ค๋๋ค.
๋ชจ๋ธ ์ ์ฅํ๊ธฐ[[saving-a-model]]
BertModel
์ TorchScript๋ก ๋ด๋ณด๋ด๋ ค๋ฉด BertConfig
ํด๋์ค์์ BertModel
์ ์ธ์คํด์คํํ ๋ค์, traced_bert.pt
๋ผ๋ ํ์ผ๋ช
์ผ๋ก ๋์คํฌ์ ์ ์ฅํ๋ฉด ๋ฉ๋๋ค.
from transformers import BertModel, BertTokenizer, BertConfig
import torch
enc = BertTokenizer.from_pretrained("bert-base-uncased")
# ์
๋ ฅ ํ
์คํธ ํ ํฐํํ๊ธฐ
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = enc.tokenize(text)
# ์
๋ ฅ ํ ํฐ ์ค ํ๋๋ฅผ ๋ง์คํนํ๊ธฐ
masked_index = 8
tokenized_text[masked_index] = "[MASK]"
indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
# ๋๋ฏธ ์
๋ ฅ ๋ง๋ค๊ธฐ
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
dummy_input = [tokens_tensor, segments_tensors]
# torchscript ํ๋๊ทธ๋ก ๋ชจ๋ธ ์ด๊ธฐํํ๊ธฐ
# ์ด ๋ชจ๋ธ์ LM ํค๋๊ฐ ์์ผ๋ฏ๋ก ํ์ํ์ง ์์ง๋ง, ํ๋๊ทธ๋ฅผ True๋ก ์ค์ ํฉ๋๋ค.
config = BertConfig(
vocab_size_or_config_json_file=32000,
hidden_size=768,
num_hidden_layers=12,
num_attention_heads=12,
intermediate_size=3072,
torchscript=True,
)
# ๋ชจ๋ธ์ ์ธ์คํดํธํํ๊ธฐ
model = BertModel(config)
# ๋ชจ๋ธ์ ํ๊ฐ ๋ชจ๋๋ก ๋์ด์ผ ํฉ๋๋ค.
model.eval()
# ๋ง์ฝ *from_pretrained*๋ฅผ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ์ธ์คํด์คํํ๋ ๊ฒฝ์ฐ, TorchScript ํ๋๊ทธ๋ฅผ ์ฝ๊ฒ ์ค์ ํ ์ ์์ต๋๋ค
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)
# ์ถ์ ์์ฑํ๊ธฐ
traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
torch.jit.save(traced_model, "traced_bert.pt")
๋ชจ๋ธ ๊ฐ์ ธ์ค๊ธฐ[[loading-a-model]]
์ด์ ์ด์ ์ ์ ์ฅํ BertModel
, ์ฆ traced_bert.pt
๋ฅผ ๋์คํฌ์์ ๊ฐ์ ธ์ค๊ณ , ์ด์ ์ ์ด๊ธฐํํ dummy_input
์์ ์ฌ์ฉํ ์ ์์ต๋๋ค.
loaded_model = torch.jit.load("traced_bert.pt")
loaded_model.eval()
all_encoder_layers, pooled_output = loaded_model(*dummy_input)
์ถ์ ๋ ๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ ์ถ๋ก ํ๊ธฐ[[using-a-traced-model-for-inference]]
__call__
์ด์ค ์ธ๋์ค์ฝ์ด(dunder) ๋ฉ์๋๋ฅผ ์ฌ์ฉํ์ฌ ์ถ๋ก ์ ์ถ์ ๋ ๋ชจ๋ธ์ ์ฌ์ฉํ์ธ์:
traced_model(tokens_tensor, segments_tensors)
Neuron SDK๋ก Hugging Face TorchScript ๋ชจ๋ธ์ AWS์ ๋ฐฐํฌํ๊ธฐ[[deploy-hugging-face-torchscript-models-to-aws-with-the-neuron-sdk]]
AWS๊ฐ ํด๋ผ์ฐ๋์์ ์ ๋น์ฉ, ๊ณ ์ฑ๋ฅ ๋จธ์ ๋ฌ๋ ์ถ๋ก ์ ์ํ Amazon EC2 Inf1 ์ธ์คํด์ค ์ ํ๊ตฐ์ ์ถ์ํ์ต๋๋ค. Inf1 ์ธ์คํด์ค๋ ๋ฅ๋ฌ๋ ์ถ๋ก ์ํฌ๋ก๋์ ํนํ๋ ๋ง์ถค ํ๋์จ์ด ๊ฐ์๊ธฐ์ธ AWS Inferentia ์นฉ์ผ๋ก ๊ตฌ๋๋ฉ๋๋ค. AWS Neuron์ Inferentia๋ฅผ ์ํ SDK๋ก, Inf1์ ๋ฐฐํฌํ๊ธฐ ์ํ transformers ๋ชจ๋ธ ์ถ์ ๋ฐ ์ต์ ํ๋ฅผ ์ง์ํฉ๋๋ค. Neuron SDK๋ ๋ค์๊ณผ ๊ฐ์ ๊ธฐ๋ฅ์ ์ ๊ณตํฉ๋๋ค:
- ์ฝ๋ ํ ์ค๋ง ๋ณ๊ฒฝํ๋ฉด ํด๋ผ์ฐ๋ ์ถ๋ก ๋ฅผ ์ํด TorchScript ๋ชจ๋ธ์ ์ถ์ ํ๊ณ ์ต์ ํํ ์ ์๋ ์ฌ์ด API
- ์ฆ์ ์ฌ์ฉ ๊ฐ๋ฅํ ์ฑ๋ฅ ์ต์ ํ๋ก ๋น์ฉ ํจ์จ ํฅ์
- PyTorch ๋๋ TensorFlow๋ก ๊ตฌ์ถ๋ Hugging Face transformers ๋ชจ๋ธ ์ง์
์์ฌ์ [[implications]]
BERT (Bidirectional Encoder Representations from Transformers) ์ํคํ ์ฒ ๋๋ ๊ทธ ๋ณํ์ธ distilBERT ๋ฐ roBERTa๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ํ Transformers ๋ชจ๋ธ์ ์ถ์ถ ๊ธฐ๋ฐ ์ง์์๋ต, ์ํ์ค ๋ถ๋ฅ ๋ฐ ํ ํฐ ๋ถ๋ฅ์ ๊ฐ์ ๋น์์ฑ ์์ ์ Inf1์์ ์ต์์ ์ฑ๋ฅ์ ๋ณด์ ๋๋ค. ๊ทธ๋ฌ๋ ํ ์คํธ ์์ฑ ์์ ๋ AWS Neuron MarianMT ํํ ๋ฆฌ์ผ์ ๋ฐ๋ผ Inf1์์ ์คํ๋๋๋ก ์กฐ์ ํ ์ ์์ต๋๋ค.
Inferentia์์ ๋ฐ๋ก ๋ณํํ ์ ์๋ ๋ชจ๋ธ์ ๋ํ ์์ธํ ์ ๋ณด๋ Neuron ๋ฌธ์์ Model Architecture Fit ์น์ ์์ ํ์ธํ ์ ์์ต๋๋ค.
์ข ์์ฑ[[dependencies]]
AWS Neuron์ ์ฌ์ฉํ์ฌ ๋ชจ๋ธ์ ๋ณํํ๋ ค๋ฉด Neuron SDK ํ๊ฒฝ์ด ํ์ํฉ๋๋ค. ์ด๋ AWS Deep Learning AMI์ ๋ฏธ๋ฆฌ ๊ตฌ์ฑ๋์ด ์์ต๋๋ค.
AWS Neuron์ผ๋ก ๋ชจ๋ธ ๋ณํํ๊ธฐ[[converting-a-model-for-aws-neuron]]
BertModel
์ ์ถ์ ํ๋ ค๋ฉด, Python์์ TorchScript ์ฌ์ฉํ๊ธฐ์์์ ๋์ผํ ์ฝ๋๋ฅผ ์ฌ์ฉํด์ AWS NEURON์ฉ ๋ชจ๋ธ์ ๋ณํํฉ๋๋ค.
torch.neuron
ํ๋ ์์ํฌ ์ต์คํ
์
์ ๊ฐ์ ธ์ Python API๋ฅผ ํตํด Neuron SDK์ ๊ตฌ์ฑ ์์์ ์ ๊ทผํฉ๋๋ค:
from transformers import BertModel, BertTokenizer, BertConfig
import torch
import torch.neuron
๋ค์ ์ค๋ง ์์ ํ๋ฉด ๋ฉ๋๋ค:
- torch.jit.trace(model, [tokens_tensor, segments_tensors])
+ torch.neuron.trace(model, [token_tensor, segments_tensors])
์ด๋ก์จ Neuron SDK๊ฐ ๋ชจ๋ธ์ ์ถ์ ํ๊ณ Inf1 ์ธ์คํด์ค์ ์ต์ ํํ ์ ์๊ฒ ๋ฉ๋๋ค.
AWS Neuron SDK์ ๊ธฐ๋ฅ, ๋๊ตฌ, ์์ ํํ ๋ฆฌ์ผ ๋ฐ ์ต์ ์ ๋ฐ์ดํธ์ ๋ํด ์์ธํ ์์๋ณด๋ ค๋ฉด AWS NeuronSDK ๋ฌธ์๋ฅผ ์ฐธ์กฐํ์ธ์.