File size: 4,175 Bytes
33127b8
cfe253c
 
 
 
fb432f9
cfe253c
fb432f9
 
 
 
 
 
cfe253c
fb432f9
cfe253c
8644072
 
 
 
 
 
e38c504
 
 
cfe253c
 
 
 
 
 
8644072
 
e38c504
cfe253c
 
 
 
1d0c47a
cfe253c
 
604941d
 
cfe253c
 
33127b8
cfe253c
33127b8
cfe253c
 
33127b8
 
 
 
 
cfe253c
 
 
33127b8
cfe253c
 
fb432f9
cfe253c
 
 
 
 
 
 
 
 
 
 
 
 
 
33127b8
cfe253c
 
 
 
 
 
 
 
 
 
33127b8
 
 
cfe253c
33127b8
cfe253c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import numpy
from transformers import TokenClassificationPipeline

class UniversalDependenciesPipeline(TokenClassificationPipeline):
  def preprocess(self,sentence,offset_mapping=None):
    """Tokenize ``sentence`` and attach a character span to every token.

    The text is split with the ``Whitespace`` pre-tokenizer. A piece is glued
    to the preceding word with ``"_"`` whenever the glued form is not mapped
    to the tokenizer's unknown-token id; the resulting words are joined with
    single spaces and handed to the parent class's ``preprocess``.

    Args:
      sentence: raw input text.
      offset_mapping: accepted for signature compatibility; not used here.

    Returns:
      The parent's model inputs extended with ``"offset_mapping"`` (a tensor
      holding one list of (start, end) character spans, with (0, 0) in the
      first and last positions) and ``"sentence"`` (the raw text). When the
      parent returned a non-subscriptable iterable instead of a mapping, only
      its first element is used and the result is wrapped in an iterator.

    Raises:
      IndexError: if the pre-tokenizer yields no pieces at all.
    """
    import torch
    from tokenizers.pre_tokenizers import Whitespace
    pieces=Whitespace().pre_tokenize_str(sentence)
    words=[pieces[0]]
    for text,(start,end) in pieces[1:]:
      joined=words[-1][0]+"_"+text
      if self.tokenizer.convert_tokens_to_ids(joined)!=self.tokenizer.unk_token_id:
        # The glued form is in the vocabulary: extend the previous word so
        # that its span runs from its own start to the end of this piece.
        words[-1]=(joined,(words[-1][1][0],end))
      else:
        words.append((text,(start,end)))
    # (0,0) placeholders line up with the first and last token positions.
    # NOTE(review): presumably the tokenizer's begin/end special tokens - confirm.
    offsets=[(0,0)]+[span for _,span in words]+[(0,0)]
    inputs=super().preprocess(sentence=" ".join(text for text,_ in words))
    # The parent may hand back either a subscriptable mapping or a lazy
    # iterable of mappings. Subscripting the latter raises TypeError, which
    # is the only exception treated as "not a mapping"; anything else
    # (including KeyboardInterrupt) propagates instead of being swallowed.
    try:
      ids=inputs["input_ids"]
    except TypeError:
      is_mapping=False
      inputs=list(inputs)[0]
      ids=inputs["input_ids"]
    else:
      is_mapping=True
    tokens=self.tokenizer.convert_ids_to_tokens(ids[0])
    if len(offsets)!=len(tokens):
      # More tokens than words: split the span of every token ending in "@@"
      # after len(token)-2 characters, the remainder going to the next token.
      # NOTE(review): "@@" looks like a subword-continuation marker - confirm.
      for pos,token in enumerate(tokens):
        if token.endswith("@@"):
          start,end=offsets[pos]
          cut=start+len(token)-2
          offsets.insert(pos+1,(cut,end))
          offsets[pos]=(start,cut)
    inputs["offset_mapping"]=torch.tensor([offsets]).to(self.device)
    inputs["sentence"]=sentence
    return inputs if is_mapping else iter([inputs])
  def _forward(self,model_inputs):
    """Run the model once per inner token, masking that token in turn.

    For every position between the first and last ids, one input row is
    built in which that position is replaced by the mask-token id and the
    original id is appended at the end. All rows go through the model as a
    single batch with gradients disabled.

    Returns:
      A dict with ``"logits"`` (the model logits without the first and the
      last two sequence positions) followed by every entry of
      ``model_inputs``.
    """
    import torch
    ids=model_inputs["input_ids"][0].tolist()
    rows=[
      ids[:pos]+[self.tokenizer.mask_token_id]+ids[pos+1:]+[ids[pos]]
      for pos in range(1,len(ids)-1)
    ]
    with torch.no_grad():
      batch=torch.tensor(rows,device=self.device)
      output=self.model(input_ids=batch)
    result={"logits":output.logits[:,1:-2,:]}
    result.update(model_inputs)
    return result
  def postprocess(self,model_outputs,**kwargs):
    """Turn the scores from ``_forward`` into a CoNLL-U-style text block.

    ``model_outputs["logits"]`` is used as an array ``e[i,j,label]`` where
    cell (i,j) scores token j taking token i as its head; the diagonal
    stands for "token is the root" (see the HEAD column below, which prints
    0 when a token's head is itself).

    Args:
      model_outputs: dict with ``"logits"``, ``"offset_mapping"`` and
        ``"sentence"``, or an iterable of such dicts (each is formatted
        separately and the strings are concatenated).
      **kwargs: only ``aggregation_strategy`` is read; any value other than
        ``"none"`` merges ``goeswith`` tokens into the preceding token and
        fills the lemma column with the surface form.

    Returns:
      A string starting with a ``# text = ...`` line, followed by one
      10-column tab-separated line per token and a trailing blank line.
    """
    # No "logits" key: treat model_outputs as a collection of outputs.
    if "logits" not in model_outputs:
      return "".join(self.postprocess(x,**kwargs) for x in model_outputs)
    e=model_outputs["logits"].numpy()
    # Per-label marker: 1 for label id 0, -1 for labels ending in "|root", 0 otherwise.
    # NOTE(review): indexing by position assumes label ids are 0..L-1 without gaps - confirm.
    r=[1 if i==0 else -1 if j.endswith("|root") else 0 for i,j in sorted(self.model.config.id2label.items())]
    # identity+marker==0 keeps "|root" labels only on the diagonal and only the
    # marker-0 labels off the diagonal; everything else is pushed to -inf.
    e+=numpy.where(numpy.add.outer(numpy.identity(e.shape[0]),r)==0,0,-numpy.inf)
    g=self.model.config.label2id["X|_|goeswith"]
    # m: best label score per (head, dependent) cell; r: 1 on and below the
    # diagonal, 0 above it. Cells where r stays 0 may keep the goeswith label.
    m,r=numpy.max(e,axis=2),numpy.tri(e.shape[0])
    for i in range(e.shape[0]):
      for j in range(i+2,e.shape[1]):
        # Two or more positions to the right of i: blocked by default ...
        r[i,j]=1
        # ... unless the cell to the left has goeswith as its best label and
        # i is the best-scoring row for that column; then inherit its state.
        if numpy.argmax(e[i,j-1])==g and numpy.argmax(m[:,j-1])==i:
          r[i,j]=r[i,j-1]
    # Forbid goeswith wherever r is non-zero.
    e[:,:,g]+=numpy.where(r==0,0,-numpy.inf)
    # Recompute best score (m) and best label id (p) per cell after masking.
    m,p=numpy.max(e,axis=2),numpy.argmax(e,axis=2)
    h=self.chu_liu_edmonds(m)
    # z: tokens that are their own head, i.e. roots.
    z=[i for i,j in enumerate(h) if i==j]
    if len(z)>1:
      # Several roots: keep the one with the best diagonal score (k) and add
      # the penalty min(m)-max(m) to root columns so that each other root may
      # only attach to another member of z, then decode again.
      k,h=z[numpy.argmax(m[z,z])],numpy.min(m)-numpy.max(m)
      m[:,z]+=[[0 if j in z and (i!=j or i==k) else h for i in z] for j in range(m.shape[0])]
      h=self.chu_liu_edmonds(m)
    # Character spans of real tokens; empty spans such as (0,0) are dropped.
    v=[(s,e) for s,e in model_outputs["offset_mapping"][0].tolist() if s<e]
    # Label of each token under its chosen head, split on "|"; the first
    # field is printed as column 4, the last as column 8, the rest as column 6.
    q=[self.model.config.id2label[p[j,i]].split("|") for i,j in enumerate(h)]
    # From here on g is a flag: was an aggregation strategy other than "none" requested?
    g="aggregation_strategy" in kwargs and kwargs["aggregation_strategy"]!="none"
    if g:
      # Walk right-to-left so removals do not disturb indices still to be visited.
      for i,j in reversed(list(enumerate(q[1:],1))):
        # Merge token i into token i-1 when it and every token after its head
        # up to itself are labelled goeswith.
        if j[-1]=="goeswith" and set([t[-1] for t in q[h[i]+1:i+1]])=={"goeswith"}:
          # Drop entry i and shift head indices >= i down by one.
          h=[b if i>b else b-1 for a,b in enumerate(h) if i!=a]
          # Extend the previous span to the end of the removed one.
          v[i-1]=(v[i-1][0],v.pop(i)[1])
          q.pop(i)
    t=model_outputs["sentence"].replace("\n"," ")
    u="# text = "+t+"\n"
    for i,(s,e) in enumerate(v):
      # Columns: index, form, lemma (form when aggregating, else "_"), first
      # label field, "_", middle label fields, head (0 for root), last label
      # field, "_", and "_" if a gap precedes the next token else "SpaceAfter=No".
      u+="\t".join([str(i+1),t[s:e],t[s:e] if g else "_",q[i][0],"_","|".join(q[i][1:-1]),str(0 if h[i]==i else h[i]+1),q[i][-1],"_","_" if i+1<len(v) and e<v[i+1][0] else "SpaceAfter=No"])+"\n"
    return u+"\n"
  def chu_liu_edmonds(self,matrix):
    """Pick one head per column of ``matrix``, resolving cycles recursively.

    ``matrix[i,j]`` is treated as the score of row i being the head of
    column j; a node whose chosen head is itself counts as a root. As the
    name suggests, this appears to follow the Chu-Liu/Edmonds scheme:
    take the greedy best head per node, and if that produces a cycle,
    contract the cycle into one node, solve the smaller problem, and expand.

    Args:
      matrix: square 2-D numpy array of scores.

    Returns:
      Head index for every node: the numpy array from ``argmax`` when the
      greedy choice is returned directly, otherwise a Python list.
    """
    # Greedy choice: best-scoring row for every column.
    h=numpy.argmax(matrix,axis=0)
    # Working copy in which self-headed nodes (roots) are marked -1.
    x=[-1 if i==j else j for i,j in enumerate(h)]
    # Two in-place passes, each repeated until nothing changes:
    #  1) clear (-1) every node that no remaining entry points to;
    #     presumably only nodes lying on cycles survive this.
    #  2) replace each surviving entry by the entry of the node it points to;
    #     presumably this leaves all nodes of one cycle with a common value.
    for b in [lambda x,i,j:-1 if i not in x else x[i],lambda x,i,j:-1 if j<0 else x[j]]:
      y=[]
      while x!=y:
        y=list(x)
        for i,j in enumerate(x):
          x[i]=b(x,i,j)
      # Nothing survived: the greedy choice is returned unchanged.
      if max(x)<0:
        return h
    # y: nodes carrying the largest surviving value (one group to contract);
    # x: every other node. Both comprehensions read the old x before rebinding.
    y,x=[i for i,j in enumerate(x) if j==max(x)],[i for i,j in enumerate(x) if j<max(x)]
    # Scores relative to each column's best score (all values <= 0).
    z=matrix-numpy.max(matrix,axis=0)
    # Contracted matrix: the nodes in x keep their scores, and group y becomes
    # one extra last row/column holding the best score to or from any member;
    # the corner is the best diagonal score among the members of y.
    m=numpy.block([[z[x,:][:,x],numpy.max(z[x,:][:,y],axis=1).reshape(len(x),1)],[numpy.max(z[y,:][:,x],axis=0),numpy.max(z[y,y])]])
    # Solve the contracted problem and translate heads back to original
    # indices: the entry for the contracted node is kept as a contracted
    # index; a head inside x maps through x; a head equal to the contracted
    # node becomes the member of y with the best score into that node.
    k=[j if i==len(x) else x[j] if j<len(x) else y[numpy.argmax(z[y,x[i]])] for i,j in enumerate(self.chu_liu_edmonds(m))]
    # Members of y keep their greedy heads; the others take the translated ones.
    h=[j if i in y else k[x.index(i)] for i,j in enumerate(h)]
    # Choose which member of y gets re-attached: the one scoring best under
    # the contracted node's head, or the best diagonal score if that head is
    # the contracted node itself.
    i=y[numpy.argmax(z[x[k[-1]],y] if k[-1]<len(x) else z[y,y])]
    # Re-attach it to that outside head, or make it its own head (root).
    h[i]=x[k[-1]] if k[-1]<len(x) else i
    return h