---
license: apache-2.0
---

Usage example
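The example below assumes the `onnxruntime`, `numpy`, and `transformers` packages are installed (for example via `pip install onnxruntime numpy transformers`), and that the exported ONNX model and the tokenizer files are available at the local Windows paths used in the script; adjust `model_path` and `tokenizer_path` to match your environment.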

```python
import onnxruntime as ort
import numpy as np
from transformers import MobileBertTokenizer

# Paths to the ONNX model and the tokenizer
model_path = r'C:\NEW_tinybert\AI\tinybert_model.onnx'  # path to the ONNX model
tokenizer_path = r'C:\NEW_distilbert\AI'                # path to the local tokenizer files

# Initialize the ONNX Runtime inference session
ort_session = ort.InferenceSession(model_path)

# Load the MobileBERT tokenizer
tokenizer = MobileBertTokenizer.from_pretrained(tokenizer_path)

# Text classification function
def test_model(text):
    """
    Classify the input text with the ONNX model.

    Args:
        text (str): text to analyze

    Returns:
        str: prediction result message
    """
    # Tokenize the input text and convert it into the ONNX model's input format
    inputs = tokenizer(
        text,
        padding="max_length",   # pad the input to a fixed length of 128
        truncation=True,        # truncate longer texts
        max_length=128,         # maximum token length
        return_tensors="np"     # return NumPy arrays
    )

    # Cast the NumPy arrays to int64
    input_ids = inputs["input_ids"].astype(np.int64)
    attention_mask = inputs["attention_mask"].astype(np.int64)

    # Prepare the ONNX model inputs
    ort_inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask
    }

    # Run inference with the ONNX model
    outputs = ort_session.run(None, ort_inputs)
    logits = outputs[0]  # model output (logits)

    # Predict the class with the highest logit
    predicted_class = np.argmax(logits, axis=1).item()

    # Return the result message
    return "Possible romance scam" if predicted_class == 1 else "Not a romance scam"

# Conversation samples to test
test_texts = [
    "너 엄마 있니?",
    "저는 금융 상품을 소개하는 사람입니다. 투자하면 이익이 큽니다.",
    "๋ด ๋ณด์ง๊ฐ ๋ฌ์์ฌ๋์ด",
    "내 가슴 만질래??"
]

# Print the prediction for each test text
for text in test_texts:
    result = test_model(text)
    print(f"Input: {text} => Result: {result}")
```
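
The dictionary keys passed to `ort_session.run` must match the input names baked into the ONNX graph. If you are unsure what the exported model expects, here is a minimal sketch for inspecting the graph's input and output signatures; it reuses the same `model_path` as above, which is an assumption about your local setup:

```python
import onnxruntime as ort

# Assumed local path, same as in the usage example above; adjust as needed.
model_path = r'C:\NEW_tinybert\AI\tinybert_model.onnx'

session = ort.InferenceSession(model_path)

# Print the name, shape, and element type of every graph input and output,
# so the keys of ort_inputs and the meaning of outputs[0] can be verified.
for node in session.get_inputs():
    print("input :", node.name, node.shape, node.type)
for node in session.get_outputs():
    print("output:", node.name, node.shape, node.type)
```

Some BERT-style exports also require a `token_type_ids` input; if it appears in the list printed above, pass `inputs["token_type_ids"].astype(np.int64)` alongside the other arrays in `ort_inputs`.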