File size: 2,190 Bytes
6e6b6ef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
---
license: apache-2.0
---
์ฌ์ฉ์์
```python
import onnxruntime as ort
import numpy as np
from transformers import MobileBertTokenizer
# ๋ชจ๋ธ ๋ฐ ํ ํฌ๋์ด์ ๊ฒฝ๋ก ์ค์
model_path = r'C:\NEW_tinybert\AI\tinybert_model.onnx' # ONNX ๋ชจ๋ธ ๊ฒฝ๋ก
tokenizer_path = r'C:\NEW_distilbert\AI' # ๋ก์ปฌ ํ ํฌ๋์ด์ ๊ฒฝ๋ก
# ONNX ๋ชจ๋ธ ์ธ์
์ด๊ธฐํ
ort_session = ort.InferenceSession(model_path)
# MobileBERT ํ ํฌ๋์ด์ ๋ก๋
tokenizer = MobileBertTokenizer.from_pretrained(tokenizer_path)
# ํ
์คํธ ๋ถ๋ฅ ํจ์
def test_model(text):
"""
์
๋ ฅ๋ ํ
์คํธ๋ฅผ ONNX ๋ชจ๋ธ์ ์ฌ์ฉํด ๋ถ๋ฅํ๋ ํจ์
Args:
text (str): ๋ถ์ํ ํ
์คํธ
Returns:
str: ์์ธก ๊ฒฐ๊ณผ ๋ฉ์์ง
"""
# ์
๋ ฅ ํ
์คํธ๋ฅผ ํ ํฐํ ๋ฐ ONNX ๋ชจ๋ธ ์
๋ ฅ ํ์์ผ๋ก ๋ณํ
inputs = tokenizer(
text,
padding="max_length", # ์
๋ ฅ ๊ธธ์ด๋ฅผ 128๋ก ๊ณ ์
truncation=True, # ๊ธด ํ
์คํธ๋ ์๋ผ๋
max_length=128, # ์ต๋ ํ ํฐ ๊ธธ์ด
return_tensors="np" # NumPy ๋ฐฐ์ด ํ์์ผ๋ก ๋ฐํ
)
# NumPy ๋ฐฐ์ด์ int64๋ก ๋ณํ
input_ids = inputs["input_ids"].astype(np.int64)
attention_mask = inputs["attention_mask"].astype(np.int64)
# ONNX ๋ชจ๋ธ ์
๋ ฅ ์ค๋น
ort_inputs = {
"input_ids": input_ids,
"attention_mask": attention_mask
}
# ONNX ๋ชจ๋ธ ์ถ๋ก ์คํ
outputs = ort_session.run(None, ort_inputs)
logits = outputs[0] # ๋ชจ๋ธ ์ถ๋ ฅ (๋ก์ง ๊ฐ)
# ๋ก์ง ๊ฐ์ ํ๋ฅ ๋ก ๋ณํ ๋ฐ ํด๋์ค ์์ธก
predicted_class = np.argmax(logits, axis=1).item()
# ๊ฒฐ๊ณผ ๋ฐํ
return "๋ก๋งจ์ค ์ค์บ ์ผ ๊ฐ๋ฅ์ฑ ์์" if predicted_class == 1 else "๋ก๋งจ์ค ์ค์บ ์ด ์๋"
# ํ
์คํธํ ๋ํ ๋ด์ฉ
test_texts = [
"๋ ์๋ง ์๋?",
"์ ๋ ๊ธ์ต ์ํ์ ์๊ฐํ๋ ์ฌ๋์
๋๋ค. ํฌ์ํ๋ฉด ์ด์ต์ด ํฝ๋๋ค.",
"๋ด ๋ณด์ง๊ฐ ๋ฌ์์ฌ๋์ด",
"๋ด ๊ฐ์ด ๋ง์ง๋??"
]
# ๊ฐ ํ
์คํธ ํ
์คํธ์ ๋ํด ๊ฒฐ๊ณผ ์ถ๋ ฅ
for text in test_texts:
result = test_model(text)
print(f"์
๋ ฅ: {text} => ๊ฒฐ๊ณผ: {result}")
``` |