Commit
·
a13cd13
1
Parent(s):
d962907
algorithm improved
Browse files
ud.py
CHANGED
@@ -19,9 +19,10 @@ class UniversalDependenciesPipeline(TokenClassificationPipeline):
|
|
19 |
f=True
|
20 |
k=r["input_ids"]
|
21 |
except:
|
22 |
-
r=list(r)[0]
|
23 |
f=False
|
24 |
-
|
|
|
|
|
25 |
if len(m)!=len(w):
|
26 |
for i,j in enumerate(w):
|
27 |
if j.endswith("@@"):
|
@@ -30,9 +31,7 @@ class UniversalDependenciesPipeline(TokenClassificationPipeline):
|
|
30 |
m[i]=(s,s+len(j)-2)
|
31 |
r["offset_mapping"]=torch.tensor([m]).to(self.device)
|
32 |
r["sentence"]=sentence
|
33 |
-
if f
|
34 |
-
return r
|
35 |
-
return iter([r])
|
36 |
def _forward(self,model_inputs):
|
37 |
import torch
|
38 |
v=model_inputs["input_ids"][0].tolist()
|
|
|
19 |
f=True
|
20 |
k=r["input_ids"]
|
21 |
except:
|
|
|
22 |
f=False
|
23 |
+
r=list(r)[0]
|
24 |
+
k=r["input_ids"]
|
25 |
+
w=self.tokenizer.convert_ids_to_tokens(k[0])
|
26 |
if len(m)!=len(w):
|
27 |
for i,j in enumerate(w):
|
28 |
if j.endswith("@@"):
|
|
|
31 |
m[i]=(s,s+len(j)-2)
|
32 |
r["offset_mapping"]=torch.tensor([m]).to(self.device)
|
33 |
r["sentence"]=sentence
|
34 |
+
return r if f else iter([r])
|
|
|
|
|
35 |
def _forward(self,model_inputs):
|
36 |
import torch
|
37 |
v=model_inputs["input_ids"][0].tolist()
|