gihakkk commited on
Commit
6e6b6ef
ยท
verified ยท
1 Parent(s): 113704b

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +73 -3
README.md CHANGED
@@ -1,3 +1,73 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ ---
4
+
5
+ ์‚ฌ์šฉ์˜ˆ์‹œ
6
+
7
+ ```python
8
+ import onnxruntime as ort
9
+ import numpy as np
10
+ from transformers import MobileBertTokenizer
11
+
12
+ # ๋ชจ๋ธ ๋ฐ ํ† ํฌ๋‚˜์ด์ € ๊ฒฝ๋กœ ์„ค์ •
13
+ model_path = r'C:\NEW_tinybert\AI\tinybert_model.onnx' # ONNX ๋ชจ๋ธ ๊ฒฝ๋กœ
14
+ tokenizer_path = r'C:\NEW_distilbert\AI' # ๋กœ์ปฌ ํ† ํฌ๋‚˜์ด์ € ๊ฒฝ๋กœ
15
+
16
+ # ONNX ๋ชจ๋ธ ์„ธ์…˜ ์ดˆ๊ธฐํ™”
17
+ ort_session = ort.InferenceSession(model_path)
18
+
19
+ # MobileBERT ํ† ํฌ๋‚˜์ด์ € ๋กœ๋“œ
20
+ tokenizer = MobileBertTokenizer.from_pretrained(tokenizer_path)
21
+
22
+ # ํ…์ŠคํŠธ ๋ถ„๋ฅ˜ ํ•จ์ˆ˜
23
+ def test_model(text):
24
+ """
25
+ ์ž…๋ ฅ๋œ ํ…์ŠคํŠธ๋ฅผ ONNX ๋ชจ๋ธ์„ ์‚ฌ์šฉํ•ด ๋ถ„๋ฅ˜ํ•˜๋Š” ํ•จ์ˆ˜
26
+ Args:
27
+ text (str): ๋ถ„์„ํ•  ํ…์ŠคํŠธ
28
+ Returns:
29
+ str: ์˜ˆ์ธก ๊ฒฐ๊ณผ ๋ฉ”์‹œ์ง€
30
+ """
31
+ # ์ž…๋ ฅ ํ…์ŠคํŠธ๋ฅผ ํ† ํฐํ™” ๋ฐ ONNX ๋ชจ๋ธ ์ž…๋ ฅ ํ˜•์‹์œผ๋กœ ๋ณ€ํ™˜
32
+ inputs = tokenizer(
33
+ text,
34
+ padding="max_length", # ์ž…๋ ฅ ๊ธธ์ด๋ฅผ 128๋กœ ๊ณ ์ •
35
+ truncation=True, # ๊ธด ํ…์ŠคํŠธ๋Š” ์ž˜๋ผ๋ƒ„
36
+ max_length=128, # ์ตœ๋Œ€ ํ† ํฐ ๊ธธ์ด
37
+ return_tensors="np" # NumPy ๋ฐฐ์—ด ํ˜•์‹์œผ๋กœ ๋ฐ˜ํ™˜
38
+ )
39
+
40
+ # NumPy ๋ฐฐ์—ด์„ int64๋กœ ๋ณ€ํ™˜
41
+ input_ids = inputs["input_ids"].astype(np.int64)
42
+ attention_mask = inputs["attention_mask"].astype(np.int64)
43
+
44
+ # ONNX ๋ชจ๋ธ ์ž…๋ ฅ ์ค€๋น„
45
+ ort_inputs = {
46
+ "input_ids": input_ids,
47
+ "attention_mask": attention_mask
48
+ }
49
+
50
+ # ONNX ๋ชจ๋ธ ์ถ”๋ก  ์‹คํ–‰
51
+ outputs = ort_session.run(None, ort_inputs)
52
+ logits = outputs[0] # ๋ชจ๋ธ ์ถœ๋ ฅ (๋กœ์ง“ ๊ฐ’)
53
+
54
+ # ๋กœ์ง“ ๊ฐ’์„ ํ™•๋ฅ ๋กœ ๋ณ€ํ™˜ ๋ฐ ํด๋ž˜์Šค ์˜ˆ์ธก
55
+ predicted_class = np.argmax(logits, axis=1).item()
56
+
57
+ # ๊ฒฐ๊ณผ ๋ฐ˜ํ™˜
58
+ return "๋กœ๋งจ์Šค ์Šค์บ ์ผ ๊ฐ€๋Šฅ์„ฑ ์žˆ์Œ" if predicted_class == 1 else "๋กœ๋งจ์Šค ์Šค์บ ์ด ์•„๋‹˜"
59
+
60
+ # ํ…Œ์ŠคํŠธํ•  ๋Œ€ํ™” ๋‚ด์šฉ
61
+ test_texts = [
62
+ "๋„ˆ ์—„๋งˆ ์—†๋ƒ?",
63
+ "์ €๋Š” ๊ธˆ์œต ์ƒํ’ˆ์„ ์†Œ๊ฐœํ•˜๋Š” ์‚ฌ๋žŒ์ž…๋‹ˆ๋‹ค. ํˆฌ์žํ•˜๋ฉด ์ด์ต์ด ํฝ๋‹ˆ๋‹ค.",
64
+ "๋‚ด ๋ณด์ง€๊ฐ€ ๋‹ฌ์•„์˜ฌ๋ž์–ด",
65
+ "๋‚ด ๊ฐ€์Šด ๋งŒ์งˆ๋ž˜??"
66
+ ]
67
+
68
+ # ๊ฐ ํ…Œ์ŠคํŠธ ํ…์ŠคํŠธ์— ๋Œ€ํ•ด ๊ฒฐ๊ณผ ์ถœ๋ ฅ
69
+ for text in test_texts:
70
+ result = test_model(text)
71
+ print(f"์ž…๋ ฅ: {text} => ๊ฒฐ๊ณผ: {result}")
72
+
73
+ ```