initial commit of onnx file and conversion script
- README.md +9 -0
- splade_pp_en_v2.onnx +3 -0
- splade_pp_en_v2_to_onnx.py +76 -0
README.md
ADDED
@@ -0,0 +1,9 @@
# ONNX model for Splade_PP_en_v2

See [https://huggingface.co/prithivida/Splade_PP_en_v2](https://huggingface.co/prithivida/Splade_PP_en_v2)

This is just a script for ONNX conversion, plus the resulting ONNX model, with an output format compatible with the [anserini](https://github.com/castorini/anserini) SparseEncoder implementations.

```
python splade_pp_en_v2_to_onnx.py splade_pp_en_v2.onnx
```
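For reference, a minimal sketch of querying the exported model with `onnxruntime` (not part of this commit). It assumes `onnxruntime` and `transformers` are installed and that the tokenizer is loaded from the original `prithivida/Splade_PP_en_v2` repo; the output names `output_idx` and `output_weights` come from the export script below.

```python
# Hedged usage sketch, not included in this repo: encode one query with the
# exported ONNX model and map the sparse term ids back to tokens.
import onnxruntime as ort
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("prithivida/Splade_PP_en_v2")
session = ort.InferenceSession("splade_pp_en_v2.onnx")

encoded = tokenizer("what is splade?", return_tensors="np")
indices, weights = session.run(
    ["output_idx", "output_weights"],
    {
        "input_ids": encoded["input_ids"],
        "token_type_ids": encoded["token_type_ids"],
        "attention_mask": encoded["attention_mask"],
    },
)

# indices are vocabulary ids of the non-zero SPLADE terms, weights their scores.
terms = tokenizer.convert_ids_to_tokens(indices.tolist())
print(dict(zip(terms, weights.tolist())))
```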
splade_pp_en_v2.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7e85eac8d47460ffd7a9cc0199228b540081e5bc2c82b6a8963e3d0b6103e4ed
size 532158694
splade_pp_en_v2_to_onnx.py
ADDED
@@ -0,0 +1,76 @@
import sys

import torch
from transformers import AutoModelForMaskedLM  # type: ignore

# Convert prithivida/Splade_PP_en_v2 to ONNX.
# Based on:
# - https://github.com/naver/splade/issues/47
# - https://github.com/castorini/anserini/blob/master/docs/onnx-conversion.md


class TransformerRep(torch.nn.Module):
    """Wraps the masked-LM backbone and returns its raw logits."""

    def __init__(self):
        super().__init__()
        self.model = AutoModelForMaskedLM.from_pretrained('prithivida/Splade_PP_en_v2')
        self.model.eval()  # type: ignore
        self.fp16 = True

    def encode(self, input_ids, token_type_ids, attention_mask):
        # MLM logits of shape (batch, sequence, vocab_size).
        return self.model(
            input_ids=input_ids,
            token_type_ids=token_type_ids,
            attention_mask=attention_mask
        )[0]


class SpladeModel(torch.nn.Module):
    """SPLADE pooling: max over the sequence of log(1 + relu(logits)), masked by attention."""

    def __init__(self):
        super().__init__()
        self.model = TransformerRep()
        self.agg = "max"
        self.model.eval()

    def forward(self, input_ids, token_type_ids, attention_mask):
        with torch.cuda.amp.autocast():  # type: ignore
            with torch.no_grad():
                # Logits for the first (and only) sequence in the batch.
                lm_logits = self.model.encode(input_ids, token_type_ids, attention_mask)[0]
                # Max-pooled sparse vector over the vocabulary.
                vec, _ = torch.max(torch.log(1 + torch.relu(lm_logits)) * attention_mask.unsqueeze(-1), dim=1)
                # Non-zero vocabulary ids and their weights, matching the
                # output_idx / output_weights names used in the export below.
                indices = vec.nonzero().squeeze()
                weights = vec.squeeze()[indices]
                return indices[:, 1], weights[:, 1]


if __name__ == '__main__':
    if len(sys.argv) != 2:
        print('Usage:', sys.argv[0], '<output-file-name>')
        sys.exit(1)

    # Convert the model to TorchScript by tracing with dummy inputs (batch of 1, length 50).
    model = SpladeModel()

    input_ids = torch.randint(1, 100, size=(1, 50))
    token_type_ids = torch.full((1, 50), 0)
    attention_mask = torch.full((1, 50), 1)
    traced_model = torch.jit.trace(model, (input_ids, token_type_ids, attention_mask))

    dyn_axis = {
        'input_ids': {0: 'batch_size', 1: 'sequence'},
        'attention_mask': {0: 'batch_size', 1: 'sequence'},
        'token_type_ids': {0: 'batch_size', 1: 'sequence'},
        'output_idx': {0: 'batch_size', 1: 'sequence'},
        'output_weights': {0: 'batch_size', 1: 'sequence'}
    }

    # Export the traced model to ONNX; the file is written to the path given on the command line.
    torch.onnx.export(
        traced_model,
        (input_ids, token_type_ids, attention_mask),  # type: ignore
        f=sys.argv[1],
        input_names=['input_ids', 'token_type_ids', 'attention_mask'],
        output_names=['output_idx', 'output_weights'],
        dynamic_axes=dyn_axis,
        do_constant_folding=True,
        opset_version=15,
        verbose=False,
    )
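Not part of the commit, but a quick sanity check after running the script could look like the sketch below; it assumes the `onnx` and `onnxruntime` packages are installed alongside the script's dependencies.

```python
# Hedged post-export check, not included in this repo: validate the exported
# graph and run it once through onnxruntime with dummy int64 inputs.
import numpy as np
import onnx
import onnxruntime as ort

onnx.checker.check_model(onnx.load("splade_pp_en_v2.onnx"))

session = ort.InferenceSession("splade_pp_en_v2.onnx")
feed = {
    "input_ids": np.random.randint(1, 100, size=(1, 50), dtype=np.int64),
    "token_type_ids": np.zeros((1, 50), dtype=np.int64),
    "attention_mask": np.ones((1, 50), dtype=np.int64),
}
idx, weights = session.run(["output_idx", "output_weights"], feed)
print(idx.shape, weights.shape)  # one sparse term id per weight
```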