KoichiYasuoka commited on
Commit
a13cd13
·
1 Parent(s): d962907

algorithm improved

Browse files
Files changed (1) hide show
  1. ud.py +4 -5
ud.py CHANGED
@@ -19,9 +19,10 @@ class UniversalDependenciesPipeline(TokenClassificationPipeline):
19
  f=True
20
  k=r["input_ids"]
21
  except:
22
- r=list(r)[0]
23
  f=False
24
- w=self.tokenizer.convert_ids_to_tokens(r["input_ids"][0])
 
 
25
  if len(m)!=len(w):
26
  for i,j in enumerate(w):
27
  if j.endswith("@@"):
@@ -30,9 +31,7 @@ class UniversalDependenciesPipeline(TokenClassificationPipeline):
30
  m[i]=(s,s+len(j)-2)
31
  r["offset_mapping"]=torch.tensor([m]).to(self.device)
32
  r["sentence"]=sentence
33
- if f:
34
- return r
35
- return iter([r])
36
  def _forward(self,model_inputs):
37
  import torch
38
  v=model_inputs["input_ids"][0].tolist()
 
19
  f=True
20
  k=r["input_ids"]
21
  except:
 
22
  f=False
23
+ r=list(r)[0]
24
+ k=r["input_ids"]
25
+ w=self.tokenizer.convert_ids_to_tokens(k[0])
26
  if len(m)!=len(w):
27
  for i,j in enumerate(w):
28
  if j.endswith("@@"):
 
31
  m[i]=(s,s+len(j)-2)
32
  r["offset_mapping"]=torch.tensor([m]).to(self.device)
33
  r["sentence"]=sentence
34
+ return r if f else iter([r])
 
 
35
  def _forward(self,model_inputs):
36
  import torch
37
  v=model_inputs["input_ids"][0].tolist()