---
license: apache-2.0
---
Usage example
```python
import onnxruntime as ort
import numpy as np
from transformers import MobileBertTokenizer

# Model and tokenizer paths
model_path = r'C:\NEW_tinybert\AI\tinybert_model.onnx'  # path to the ONNX model
tokenizer_path = r'C:\NEW_distilbert\AI'  # path to the local tokenizer

# Initialize the ONNX Runtime inference session
ort_session = ort.InferenceSession(model_path)

# Load the MobileBERT tokenizer
tokenizer = MobileBertTokenizer.from_pretrained(tokenizer_path)

def test_model(text):
    """
    Classify the input text with the ONNX model.

    Args:
        text (str): text to analyze

    Returns:
        str: prediction message
    """
    # Tokenize the input text into the ONNX model's input format
    inputs = tokenizer(
        text,
        padding="max_length",  # pad inputs to a fixed length of 128
        truncation=True,       # truncate longer texts
        max_length=128,        # maximum token length
        return_tensors="np"    # return NumPy arrays
    )

    # Cast the NumPy arrays to int64
    input_ids = inputs["input_ids"].astype(np.int64)
    attention_mask = inputs["attention_mask"].astype(np.int64)

    # Prepare the ONNX model inputs
    ort_inputs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask
    }

    # Run ONNX model inference
    outputs = ort_session.run(None, ort_inputs)
    logits = outputs[0]  # model output (logits)

    # Predict the class with the highest logit
    predicted_class = np.argmax(logits, axis=1).item()

    # Return the result
    return "Likely a romance scam" if predicted_class == 1 else "Not a romance scam"

# Conversation snippets to test (kept in Korean, since the model classifies Korean text)
test_texts = [
    "๋„ˆ ์—„๋งˆ ์—†๋ƒ?",
    "์ €๋Š” ๊ธˆ์œต ์ƒํ’ˆ์„ ์†Œ๊ฐœํ•˜๋Š” ์‚ฌ๋žŒ์ž…๋‹ˆ๋‹ค. ํˆฌ์žํ•˜๋ฉด ์ด์ต์ด ํฝ๋‹ˆ๋‹ค.",
    "๋‚ด ๋ณด์ง€๊ฐ€ ๋‹ฌ์•„์˜ฌ๋ž์–ด",
    "๋‚ด ๊ฐ€์Šด ๋งŒ์งˆ๋ž˜??"
]

# Print the result for each test text
for text in test_texts:
    result = test_model(text)
    print(f"Input: {text} => Result: {result}")
```
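
If you also want a confidence score rather than just the predicted class, the raw logits can be converted to probabilities with a softmax before taking the argmax. A minimal sketch using dummy logits in place of the model output above (the `[[0.3, 2.1]]` values are illustrative, not from the model):

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - np.max(logits, axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=-1, keepdims=True)

# Dummy logits standing in for ort_session.run(None, ort_inputs)[0]
# (shape: batch_size x 2 classes)
logits = np.array([[0.3, 2.1]])

probs = softmax(logits)                              # per-class probabilities
predicted_class = int(np.argmax(probs, axis=1).item())
confidence = float(probs[0, predicted_class])        # probability of the chosen class
```

Subtracting the row-wise maximum before exponentiating avoids overflow for large logits without changing the result.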