Update README.md
Browse files
README.md
CHANGED
@@ -1,11 +1,84 @@
|
|
1 |
---
|
2 |
-
license: mit
|
3 |
-
datasets:
|
4 |
-
- google/code_x_glue_tc_nl_code_search_adv
|
5 |
-
base_model:
|
6 |
-
- Shuu12121/CodeMorph-BERT
|
7 |
-
pipeline_tag: fill-mask
|
8 |
-
tags:
|
9 |
-
- code
|
10 |
-
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
license: mit
|
3 |
+
datasets:
|
4 |
+
- google/code_x_glue_tc_nl_code_search_adv
|
5 |
+
base_model:
|
6 |
+
- Shuu12121/CodeMorph-BERT
|
7 |
+
pipeline_tag: fill-mask
|
8 |
+
tags:
|
9 |
+
- code
|
10 |
+
|
11 |
+
description: |
|
12 |
+
このモデルは `Shuu12121/CodeMorph-BERT` をベースに、`google/code_x_glue_tc_nl_code_search_adv` データセットを用いて追加トレーニングを行いました。
|
13 |
+
これにより、コード検索精度の向上を実現しました。
|
14 |
+
|
15 |
+
### 🔬 **実験結果**
|
16 |
+
|
17 |
+
#### **CodeSearchNet (候補プールサイズ: 100) における評価**
|
18 |
+
| Metric | CodeMorph-BERT | CodeMorph-BERTv2 | Microsoft CodeBERT |
|
19 |
+
|---------------|---------------|------------------|--------------------|
|
20 |
+
| **MRR** | 0.6678 | 0.6607 | 0.5598 |
|
21 |
+
| **MAP** | 0.6678 | 0.6607 | 0.5598 |
|
22 |
+
| **R-Precision** | 0.5650 | 0.5510 | 0.4650 |
|
23 |
+
| **Recall@1** | 0.5650 | 0.5510 | 0.4650 |
|
24 |
+
| **Recall@5** | 0.7970 | 0.7970 | 0.6490 |
|
25 |
+
| **Recall@10** | 0.8600 | 0.8630 | 0.7410 |
|
26 |
+
| **Recall@50** | 0.9770 | 0.9790 | 0.9640 |
|
27 |
+
| **Recall@100** | 1.0000 | 1.0000 | 1.0000 |
|
28 |
+
| **Precision@1** | 0.5650 | 0.5510 | 0.4650 |
|
29 |
+
| **NDCG@10** | 0.7091 | 0.7050 | 0.5936 |
|
30 |
+
|
31 |
+
#### **google/code_x_glue_tc_nl_code_search_adv データセットでの評価**
|
32 |
+
| Metric | CodeMorph-BERT | CodeMorph-BERTv2 | Microsoft CodeBERT |
|
33 |
+
|---------------|---------------|------------------|--------------------|
|
34 |
+
| **MRR** | 0.3023 | 0.3154 | 0.3562 |
|
35 |
+
| **Recall@1** | 0.2035 | 0.2185 | 0.2527 |
|
36 |
+
| **Recall@5** | 0.3909 | 0.4021 | 0.4490 |
|
37 |
+
| **Recall@10** | 0.4936 | 0.5046 | 0.5622 |
|
38 |
+
| **Recall@50** | 0.8134 | 0.8231 | 0.8995 |
|
39 |
+
| **Recall@100** | 1.0000 | 1.0000 | 1.0000 |
|
40 |
+
| **Precision@1** | 0.2035 | 0.2185 | 0.2527 |
|
41 |
+
| **NDCG@10** | 0.3344 | 0.3469 | 0.3912 |
|
42 |
+
|
43 |
+
---
|
44 |
+
## 📝 **CodeMorph-BERT vs CodeMorph-BERTv2 の比較**
|
45 |
+
- **CodeSearchNet におけるパフォーマンス**
|
46 |
+
- `CodeMorph-BERTv2` は `CodeMorph-BERT` とほぼ同等の精度を維持しており、リコールとNDCGにおいてわずかに改善が見られます。
|
47 |
+
- `Recall@10` や `Recall@50` は v2 のほうがわずかに高く、全体的に検索精度が向上しています。
|
48 |
+
|
49 |
+
- **google/code_x_glue_tc_nl_code_search_adv におけるパフォーマンス**
|
50 |
+
- `CodeMorph-BERTv2` は `CodeMorph-BERT` より **MRR、Recall@1、Recall@5、Recall@10** の値が向上しており、検索精度の改善が見られます。
|
51 |
+
- `Precision@1` も向上しており、上位検索結果の精度が向上していることが分かります。
|
52 |
+
|
53 |
+
- **総合的な考察**
|
54 |
+
- `CodeMorph-BERTv2` は `CodeMorph-BERT` と比べて **コード検索の再現率 (Recall) を向上** させており、より優れた検索結果を返すことが可能になりました。
|
55 |
+
- ただし、`Microsoft CodeBERT` と比較すると、`google/code_x_glue_tc_nl_code_search_adv` のデータセットではまだ改善の余地があり、さらなるチューニングの余地が残っています。
|
56 |
+
|
57 |
+
---
|
58 |
+
## 💡 **使用方法**
|
59 |
+
```python
|
60 |
+
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
61 |
+
import torch
|
62 |
+
|
63 |
+
# モデルとトークナイザーのロード
|
64 |
+
model_name = "Shuu12121/CodeMorph-BERTv2"
|
65 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
66 |
+
model = AutoModelForMaskedLM.from_pretrained(model_name)
|
67 |
+
|
68 |
+
# 入力コード([MASK] を含む)
|
69 |
+
text = "def add(a, b): return a [MASK] b"
|
70 |
+
|
71 |
+
# トークナイズ
|
72 |
+
inputs = tokenizer(text, return_tensors="pt")
|
73 |
+
|
74 |
+
# 推論実行
|
75 |
+
with torch.no_grad():
|
76 |
+
outputs = model(**inputs)
|
77 |
+
logits = outputs.logits
|
78 |
+
|
79 |
+
# マスク位置のトークンを予測
|
80 |
+
mask_token_index = torch.where(inputs.input_ids == tokenizer.mask_token_id)[1]
|
81 |
+
predicted_token_id = logits[0, mask_token_index, :].argmax(axis=-1)
|
82 |
+
predicted_token = tokenizer.decode(predicted_token_id)
|
83 |
+
|
84 |
+
print("予測されたトークン:", predicted_token)
|