Commit 6ce5455 (parent: 361a4ea), committed by jmg2016

initial commit of onnx file and conversion script

Files changed (3):
  1. README.md +9 -0
  2. splade_pp_en_v2.onnx +3 -0
  3. splade_pp_en_v2_to_onnx.py +76 -0
README.md ADDED
@@ -0,0 +1,9 @@
+ # ONNX model for Splade_PP_en_v2
+
+ See [https://huggingface.co/prithivida/Splade_PP_en_v2](https://huggingface.co/prithivida/Splade_PP_en_v2)
+
+ This repo contains just the ONNX conversion script and the resulting ONNX model, with an output format compatible with the [anserini](https://github.com/castorini/anserini) SparseEncoder implementations.
+
+ ```
+ python splade_pp_en_v2_to_onnx.py splade_pp_en_v2.onnx
+ ```
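For context on how the exported file is meant to be consumed, here is a minimal sketch (not part of this commit) of querying the model with onnxruntime and the upstream tokenizer. The query string, onnxruntime usage, and printing are illustrative assumptions; the input and output names come from the conversion script below.

```
# Illustrative only, not part of this commit. Assumes onnxruntime and transformers
# are installed and that splade_pp_en_v2.onnx was produced by the script in this repo.
import onnxruntime as ort
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('prithivida/Splade_PP_en_v2')
session = ort.InferenceSession('splade_pp_en_v2.onnx')

# The BERT-based tokenizer returns input_ids, token_type_ids and attention_mask.
features = tokenizer('what is a sparse lexical model?', return_tensors='np')
indices, weights = session.run(
    ['output_idx', 'output_weights'],
    {
        'input_ids': features['input_ids'],
        'token_type_ids': features['token_type_ids'],
        'attention_mask': features['attention_mask'],
    },
)

# Each index is a vocabulary id; map it back to a token to inspect the expansion.
for idx, weight in zip(indices.tolist(), weights.tolist()):
    print(tokenizer.convert_ids_to_tokens([idx])[0], round(weight, 3))
```

The returned pair of vocabulary indices and term weights is the output format the README describes as compatible with Anserini's SparseEncoder implementations.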
splade_pp_en_v2.onnx ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7e85eac8d47460ffd7a9cc0199228b540081e5bc2c82b6a8963e3d0b6103e4ed
+ size 532158694
splade_pp_en_v2_to_onnx.py ADDED
@@ -0,0 +1,76 @@
+ import sys
+ import torch
+ from transformers import AutoModelForMaskedLM  # type: ignore
+
+ # Convert prithivida/Splade_PP_en_v2 to ONNX.
+ # Based on:
+ # - https://github.com/naver/splade/issues/47
+ # - https://github.com/castorini/anserini/blob/master/docs/onnx-conversion.md
+
+
+ class TransformerRep(torch.nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.model = AutoModelForMaskedLM.from_pretrained('prithivida/Splade_PP_en_v2')
+         self.model.eval()  # type: ignore
+         self.fp16 = True
+
+     def encode(self, input_ids, token_type_ids, attention_mask):
+         # Return the MLM logits, shape (batch_size, sequence_length, vocab_size).
+         return self.model(
+             input_ids=input_ids,
+             token_type_ids=token_type_ids,
+             attention_mask=attention_mask
+         )[0]
+
+
+ class SpladeModel(torch.nn.Module):
+     def __init__(self):
+         super().__init__()
+         self.model = TransformerRep()
+         self.agg = "max"
+         self.model.eval()
+
+     def forward(self, input_ids, token_type_ids, attention_mask):
+         # SPLADE aggregation: max over the sequence of log(1 + relu(logits)),
+         # masked by the attention mask. Assumes a single sequence (batch size 1).
+         with torch.cuda.amp.autocast():  # type: ignore
+             with torch.no_grad():
+                 lm_logits = self.model.encode(input_ids, token_type_ids, attention_mask)[0]
+                 vec, _ = torch.max(torch.log(1 + torch.relu(lm_logits)) * attention_mask.unsqueeze(-1), dim=1)
+                 # vec has shape (1, vocab_size); nonzero() yields rows of
+                 # [batch_idx, vocab_idx], so column 1 holds the term ids and
+                 # indexes the corresponding weights.
+                 indices = vec.nonzero().squeeze()
+                 weights = vec.squeeze()[indices]
+         return indices[:, 1], weights[:, 1]
+
+
+ if __name__ == '__main__':
+     if len(sys.argv) != 2:
+         print('Usage:', sys.argv[0], '<output-file-name>')
+         sys.exit(1)
+
+     # Trace the model to TorchScript with a dummy batch of one 50-token sequence.
+     model = SpladeModel()
+
+     input_ids = torch.randint(1, 100, size=(1, 50))
+     token_type_ids = torch.full((1, 50), 0)
+     attention_mask = torch.full((1, 50), 1)
+     traced_model = torch.jit.trace(model, (input_ids, token_type_ids, attention_mask))
+
+     # All inputs and outputs get dynamic batch and sequence dimensions.
+     dyn_axis = {
+         'input_ids': {0: 'batch_size', 1: 'sequence'},
+         'attention_mask': {0: 'batch_size', 1: 'sequence'},
+         'token_type_ids': {0: 'batch_size', 1: 'sequence'},
+         'output_idx': {0: 'batch_size', 1: 'sequence'},
+         'output_weights': {0: 'batch_size', 1: 'sequence'}
+     }
+
+     # torch.onnx.export writes the model to sys.argv[1]; it does not return a model object.
+     torch.onnx.export(
+         traced_model,
+         (input_ids, token_type_ids, attention_mask),  # type: ignore
+         f=sys.argv[1],
+         input_names=['input_ids', 'token_type_ids', 'attention_mask'],
+         output_names=['output_idx', 'output_weights'],
+         dynamic_axes=dyn_axis,
+         do_constant_folding=True,
+         opset_version=15,
+         verbose=False,
+     )
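A quick post-export sanity check (not included in the script) is to compare the ONNX output against the traced TorchScript module on the same dummy input. The sketch below assumes it runs in the same Python session as the script above, after the export has completed, with onnxruntime and numpy installed; the tolerances are arbitrary choices.

```
# Sketch of a post-export check, assuming `traced_model`, `input_ids`,
# `token_type_ids`, `attention_mask` and `sys` from the script above are
# still in scope and the model was written to sys.argv[1].
import numpy as np
import onnxruntime as ort

session = ort.InferenceSession(sys.argv[1])
onnx_idx, onnx_weights = session.run(
    None,
    {
        'input_ids': input_ids.numpy(),
        'token_type_ids': token_type_ids.numpy(),
        'attention_mask': attention_mask.numpy(),
    },
)

# Reference output from the traced PyTorch model on the same dummy input.
ref_idx, ref_weights = traced_model(input_ids, token_type_ids, attention_mask)

np.testing.assert_array_equal(ref_idx.numpy(), onnx_idx)
np.testing.assert_allclose(ref_weights.numpy(), onnx_weights, rtol=1e-3, atol=1e-4)
print('ONNX output matches the traced PyTorch model on the dummy input')
```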