Export to TorchScript[[export-to-torchscript]]

Our experiments with TorchScript are still at an early stage, and we are still exploring its capabilities with variable-input-size models. It is an area of active interest for us, and we will deepen our analysis in upcoming releases with more code examples, a more flexible implementation, and benchmarks comparing Python-based code with compiled TorchScript.

According to the TorchScript documentation:

TorchScript is a way to create serializable and optimizable models from PyTorch code.

JIT and TRACE are PyTorch modules that let developers export their models for reuse in other programs, such as efficiency-oriented C++ programs.

We provide an interface for exporting 🤗 Transformers models to TorchScript so that they can be reused in an environment other than a PyTorch-based Python program. This document explains how to export and use models with TorchScript.

Exporting a model requires two things:

  • model instantiation with the torchscript flag
  • a forward pass with dummy inputs

These requirements imply several things developers should be careful about, as detailed below.

TorchScript flag and tied weights[[torchscript-flag-and-tied-weights]]

The torchscript flag is necessary because most 🤗 Transformers language models have tied weights between their Embedding layer and their Decoding layer. TorchScript cannot export models that have tied weights, so the weights must be untied and cloned beforehand.

Models instantiated with the torchscript flag have their Embedding layer and Decoding layer separated, so they should not be trained afterwards. Training would desynchronize the two layers and lead to unexpected results.

Models without a language model head do not have tied weights, so this problem does not arise for them. These models can be safely exported without the torchscript flag.
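As a toy illustration of why an untied clone should not be trained further, the sketch below uses plain Python lists in place of real tensors. `ToyLayer` and `sgd_step` are hypothetical names invented for this example and are not part of PyTorch or 🤗 Transformers. The only point is that a shared weight stays in sync under an in-place update, while a cloned one drifts.

```python
import copy


class ToyLayer:
    """Hypothetical stand-in for a layer that owns a weight matrix (list of rows)."""

    def __init__(self, weight):
        self.weight = weight


def sgd_step(layer, lr=0.1):
    """Pretend training step: shift every weight in place by -lr."""
    for row in layer.weight:
        for j in range(len(row)):
            row[j] -= lr


shared = [[1.0, 2.0], [3.0, 4.0]]

# Tied: the embedding and the decoder reference the same weight object.
tied_embedding = ToyLayer(shared)
tied_decoder = ToyLayer(shared)

# Untied: the decoder holds an independent copy, as it would after cloning.
untied_embedding = ToyLayer(copy.deepcopy(shared))
untied_decoder = ToyLayer(copy.deepcopy(untied_embedding.weight))

# "Train" only the embedding side in both setups.
sgd_step(tied_embedding)
sgd_step(untied_embedding)

tied_in_sync = tied_embedding.weight == tied_decoder.weight
untied_in_sync = untied_embedding.weight == untied_decoder.weight
```

After the update, `tied_in_sync` is `True` and `untied_in_sync` is `False`: the cloned decoder no longer matches its embedding, which is the desynchronization described above.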

Dummy inputs and standard lengths[[dummy-inputs-and-standard-lengths]]

The dummy inputs are used for a forward pass of the model. While the input values propagate through the layers, PyTorch keeps track of the different operations executed on each tensor. These recorded operations are then used to create the *trace* of the model.

The trace is created relative to the dimensions of the inputs. It is therefore constrained by the dimensions of the dummy input and will not work for any other sequence length or batch size. Trying a different size raises an error such as:

`The expanded size of the tensor (3) must match the existing size (7) at non-singleton dimension 2`

We recommend tracing the model with a dummy input at least as large as the largest input that will be fed to the model during inference. Padding can fill in the missing values for shorter inputs. However, because the model is traced with a larger input size, the matrix dimensions are also larger, which results in more computation.

When exporting models for varying sequence lengths, pay attention to the total number of operations performed on each input and monitor performance closely.
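The padding step mentioned above can be sketched in plain Python. This is a minimal illustration, not part of the 🤗 Transformers API: `pad_to_traced_length` is a hypothetical helper, the token ids in the usage line are arbitrary, and `pad_id=0` is an assumption that should be replaced by the padding id of the tokenizer actually in use.

```python
def pad_to_traced_length(token_ids, traced_length, pad_id=0):
    """Pad a list of token ids to the fixed length the model was traced with.

    Returns the padded ids and a mask with 1 for real tokens and 0 for padding.
    Inputs longer than the traced length are rejected rather than truncated.
    """
    if len(token_ids) > traced_length:
        raise ValueError(
            f"input has {len(token_ids)} tokens but the trace expects at most {traced_length}"
        )
    n_pad = traced_length - len(token_ids)
    padded_ids = list(token_ids) + [pad_id] * n_pad
    attention_mask = [1] * len(token_ids) + [0] * n_pad
    return padded_ids, attention_mask


# Usage: a 3-token input padded to a traced length of 6.
ids, mask = pad_to_traced_length([101, 2040, 102], traced_length=6)
```

Every input is padded to the same length, so each call pays the cost of the full traced length, which is the trade-off described above. Whether the mask can be passed to the traced model depends on which inputs were supplied when the trace was created.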

Using TorchScript in Python[[using-torchscript-in-python]]

This section demonstrates how to save and load models and how to use the trace for inference.

Saving a model[[saving-a-model]]

To export a BertModel with TorchScript, instantiate BertModel from the BertConfig class and then save it to disk under the filename traced_bert.pt:

from transformers import BertModel, BertTokenizer, BertConfig
import torch

enc = BertTokenizer.from_pretrained("bert-base-uncased")

# Tokenize the input text
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = enc.tokenize(text)

# Mask one of the input tokens
masked_index = 8
tokenized_text[masked_index] = "[MASK]"
indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]

# Create the dummy input
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
dummy_input = [tokens_tensor, segments_tensors]

# Initialize the model with the torchscript flag
# The flag is set to True even though it is not needed here, because this model has no LM head.
config = BertConfig(
    vocab_size=32000,
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    torchscript=True,
)

# Instantiate the model
model = BertModel(config)

# The model needs to be in evaluation mode
model.eval()

# If you instantiate the model with *from_pretrained*, you can also easily set the TorchScript flag
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)

# Create the trace
traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
torch.jit.save(traced_model, "traced_bert.pt")
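The hard-coded segments_ids list in the example above follows the positions of the [SEP] tokens. A minimal sketch of deriving it from the token list is shown below. `segment_ids_from_tokens` is a hypothetical helper written for this illustration, and the token list assumes the tokenizer splits the example sentence into the 14 tokens shown. The sketch also assumes at most two segments, since it simply counts separators.

```python
def segment_ids_from_tokens(tokens, sep_token="[SEP]"):
    """Assign segment 0 up to and including the first separator, then segment 1, and so on."""
    segment_ids = []
    current_segment = 0
    for token in tokens:
        segment_ids.append(current_segment)
        # The separator itself belongs to the segment it closes.
        if token == sep_token:
            current_segment += 1
    return segment_ids


# Assumed tokenization of the example sentence, with one token masked.
example_tokens = [
    "[CLS]", "who", "was", "jim", "henson", "?", "[SEP]",
    "jim", "[MASK]", "was", "a", "puppet", "##eer", "[SEP]",
]
derived_segments = segment_ids_from_tokens(example_tokens)
```

For this token list, `derived_segments` equals the hand-written list of seven zeros followed by seven ones, and its length matches the number of tokens, which is what the dummy input requires.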

Loading a model[[loading-a-model]]

You can now load the previously saved BertModel, traced_bert.pt, from disk and use it on the previously initialized dummy_input:

loaded_model = torch.jit.load("traced_bert.pt")
loaded_model.eval()

all_encoder_layers, pooled_output = loaded_model(*dummy_input)

Using a traced model for inference[[using-a-traced-model-for-inference]]

Use the traced model for inference by calling its __call__ dunder method:

traced_model(tokens_tensor, segments_tensors)

Deploy Hugging Face TorchScript models to AWS with the Neuron SDK[[deploy-hugging-face-torchscript-models-to-aws-with-the-neuron-sdk]]

AWS introduced the Amazon EC2 Inf1 instance family for low-cost, high-performance machine learning inference in the cloud. Inf1 instances are powered by the AWS Inferentia chip, a custom-built hardware accelerator specialized for deep learning inference workloads. AWS Neuron is the SDK for Inferentia, and it supports tracing and optimizing transformers models for deployment on Inf1. The Neuron SDK provides:

  1. An easy-to-use API that needs only a one-line code change to trace and optimize a TorchScript model for inference in the cloud
  2. Out-of-the-box performance optimizations for improved cost efficiency
  3. Support for Hugging Face transformers models built with either PyTorch or TensorFlow

Implications[[implications]]

Transformers models based on the BERT (Bidirectional Encoder Representations from Transformers) architecture, or its variants such as distilBERT and roBERTa, run best on Inf1 for non-generative tasks such as extractive question answering, sequence classification, and token classification. However, text generation tasks can still be adapted to run on Inf1 by following the AWS Neuron MarianMT tutorial.

More information about models that can be converted out of the box on Inferentia can be found in the Model Architecture Fit section of the Neuron documentation.

Dependencies[[dependencies]]

Using AWS Neuron to convert models requires a Neuron SDK environment, which comes preconfigured on the AWS Deep Learning AMI.

Converting a model for AWS Neuron[[converting-a-model-for-aws-neuron]]

To convert a model for AWS Neuron, use the same code as in Using TorchScript in Python to trace a BertModel. Import the torch.neuron framework extension to access the components of the Neuron SDK through a Python API:

from transformers import BertModel, BertTokenizer, BertConfig
import torch
import torch.neuron

You only need to modify the following line:

- torch.jit.trace(model, [tokens_tensor, segments_tensors])
+ torch.neuron.trace(model, [tokens_tensor, segments_tensors])

This enables the Neuron SDK to trace the model and optimize it for Inf1 instances.

To learn more about AWS Neuron SDK features, tools, example tutorials, and latest updates, see the AWS NeuronSDK documentation.