KoichiYasuoka committed on
Commit
8644072
·
1 Parent(s): 33127b8

backward compatible

Browse files
Files changed (1) hide show
  1. ud.py +13 -5
ud.py CHANGED
@@ -14,17 +14,25 @@ class UniversalDependenciesPipeline(TokenClassificationPipeline):
14
  else:
15
  t.append((k,(s,e)))
16
  m=[(0,0)]+[j for i,j in t]+[(0,0)]
17
- r=list(super().preprocess(sentence=" ".join(i for i,j in t)))
18
- w=self.tokenizer.convert_ids_to_tokens(r[0]["input_ids"][0])
 
 
 
 
 
 
19
  if len(m)!=len(w):
20
  for i,j in enumerate(w):
21
  if j.endswith("@@"):
22
  s,e=m[i]
23
  m.insert(i+1,(s+len(j)-2,e))
24
  m[i]=(s,s+len(j)-2)
25
- r[0]["offset_mapping"]=torch.tensor([m]).to(self.device)
26
- r[0]["sentence"]=sentence
27
- return iter(r)
 
 
28
  def _forward(self,model_inputs):
29
  import torch
30
  v=model_inputs["input_ids"][0].tolist()
 
14
  else:
15
  t.append((k,(s,e)))
16
  m=[(0,0)]+[j for i,j in t]+[(0,0)]
17
+ r=super().preprocess(sentence=" ".join(i for i,j in t))
18
+ try:
19
+ f=True
20
+ k=r["input_ids"]
21
+ except:
22
+ r=list(r)[0]
23
+ f=False
24
+ w=self.tokenizer.convert_ids_to_tokens(r["input_ids"][0])
25
  if len(m)!=len(w):
26
  for i,j in enumerate(w):
27
  if j.endswith("@@"):
28
  s,e=m[i]
29
  m.insert(i+1,(s+len(j)-2,e))
30
  m[i]=(s,s+len(j)-2)
31
+ r["offset_mapping"]=torch.tensor([m]).to(self.device)
32
+ r["sentence"]=sentence
33
+ if f:
34
+ return r
35
+ return iter([r])
36
  def _forward(self,model_inputs):
37
  import torch
38
  v=model_inputs["input_ids"][0].tolist()