# ner-gradio / app.py
import gradio as gr
import http.client
import json
from bs4 import BeautifulSoup as bs
import re
import pymongo
import torch
import spacy
from spacy import displacy
from pymongo import MongoClient
from transformers import AutoTokenizer, BertForTokenClassification

tokenizer = AutoTokenizer.from_pretrained("law-ai/InLegalBERT")
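# Thin wrapper around BertForTokenClassification for the InLegalBERT checkpoint (14 NER labels).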
class BertModel(torch.nn.Module):
    def __init__(self):
        super(BertModel, self).__init__()
        self.bert = BertForTokenClassification.from_pretrained('law-ai/InLegalBERT', num_labels=14)

    def forward(self, input_id, mask, label):
        output = self.bert(input_ids=input_id, attention_mask=mask, labels=label, return_dict=False)
        return output
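# The fine-tuned checkpoints below appear to have been saved as full model objects,
# so the freshly constructed BertModel() instances are immediately replaced by torch.load.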
model_preamble = BertModel()
model_preamble = torch.load("nerbert_preamble.pt", map_location=torch.device('cpu'))
model_judgment = BertModel()
model_judgment = torch.load("nerbert.pt", map_location=torch.device('cpu'))
unique_labels_preamble = {'I-PETITIONER', 'I-COURT', 'B-COURT', 'B-JUDGE', 'I-LAWYER', 'B-RESPONDENT', 'I-JUDGE', 'B-PETITIONER', 'I-RESPONDENT', 'B-LAWYER', 'O'}
unique_labels_judgment = {'B-WITNESS', 'I-PETITIONER', 'I-JUDGE', 'B-STATUTE', 'B-OTHER_PERSON', 'B-CASE_NUMBER', 'I-ORG', 'I-PRECEDENT', 'I-RESPONDENT', 'B-PROVISION', 'O', 'I-WITNESS', 'B-ORG', 'I-COURT', 'B-RESPONDENT', 'I-DATE', 'B-GPE', 'I-CASE_NUMBER', 'B-DATE', 'B-PRECEDENT', 'I-GPE', 'B-COURT', 'B-JUDGE', 'I-STATUTE', 'B-PETITIONER', 'I-OTHER_PERSON', 'I-PROVISION'}
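# Bidirectional label <-> id maps, built from the sorted label sets.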
labels_to_ids_preamble = {k: v for v, k in enumerate(sorted(unique_labels_preamble))}
ids_to_labels_preamble = {v: k for v, k in enumerate(sorted(unique_labels_preamble))}
labels_to_ids_judgment = {k: v for v, k in enumerate(sorted(unique_labels_judgment))}
ids_to_labels_judgment = {v: k for v, k in enumerate(sorted(unique_labels_judgment))}
label_all_tokens = True
def align_word_ids(texts):
    # Build a mask over the tokenized input: -100 for special/padding positions,
    # 1 for positions that correspond to real words (all sub-tokens if label_all_tokens is set).
    tokenized_inputs = tokenizer(texts, padding='max_length', max_length=512, truncation=True)
    word_ids = tokenized_inputs.word_ids()
    previous_word_idx = None
    label_ids = []
    for word_idx in word_ids:
        if word_idx is None:
            label_ids.append(-100)
        elif word_idx != previous_word_idx:
            label_ids.append(1)
        else:
            label_ids.append(1 if label_all_tokens else -100)
        previous_word_idx = word_idx
    return label_ids
def evaluate_one_preamble(model, sentence):
    # Run the preamble NER model on a single text and map the logits back to label strings.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    text = tokenizer(sentence, padding='max_length', max_length=512, truncation=True, return_tensors="pt")
    mask = text['attention_mask'].to(device)
    input_id = text['input_ids'].to(device)
    label_ids = torch.Tensor(align_word_ids(sentence)).unsqueeze(0).to(device)
    logits = model(input_id, mask, None)
    logits_clean = logits[0][label_ids != -100]
    predictions = logits_clean.argmax(dim=1).tolist()
    prediction_label = [ids_to_labels_preamble[i] for i in predictions]
    print(sentence)
    print(len(prediction_label))
    print(prediction_label)
    return (prediction_label, text)
def evaluate_one_text(model, sentence):
    # Same as evaluate_one_preamble, but uses the judgment-body label set.
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    text = tokenizer(sentence, padding='max_length', max_length=512, truncation=True, return_tensors="pt")
    mask = text['attention_mask'].to(device)
    input_id = text['input_ids'].to(device)
    label_ids = torch.Tensor(align_word_ids(sentence)).unsqueeze(0).to(device)
    logits = model(input_id, mask, None)
    logits_clean = logits[0][label_ids != -100]
    predictions = logits_clean.argmax(dim=1).tolist()
    prediction_label = [ids_to_labels_judgment[i] for i in predictions]
    print(sentence)
    print(len(prediction_label))
    print(prediction_label)
    return (prediction_label, text)
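# Remove any leftover HTML tags with a simple regex.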
def cleanhtml(raw_html):
    CLEANR = re.compile('<.*?>')
    cleantext = re.sub(CLEANR, '', raw_html)
    return cleantext
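# Blank English spaCy pipeline; used only to build a Doc that carries the predicted entity spans for displacy.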
nlp = spacy.blank("en")
def judgtext_analysis(text):
    # Fetch the judgment document from the Indian Kanoon API.
    conn = http.client.HTTPSConnection("api.indiankanoon.org")
    payload = "{}"
    headers = {
        'Authorization': 'Token ea381f5b51f9d55aaa71dfe6a90606e9b89f942a',
        'Content-Type': 'application/json'
    }
    # Parse the URL and retrieve the document id.
    d = text.split('/')
    docid = d[4]
    endpoint = "/doc/" + str(docid) + "/"
    conn.request("POST", endpoint, payload, headers)
    res = conn.getresponse()
    data = res.read()
    data = data.decode("utf-8")
    data_dict = json.loads(data)
    soup = bs(data_dict["doc"], 'html.parser')
    # Separate the judgment body (<p>/<blockquote>) from the preamble (<pre>).
    judgment_text = ""
    for tag in soup.find_all(['p', 'blockquote']):
        judgment_text += tag.text + " "
    judgment_text = cleanhtml(str(judgment_text))
    preamble_text = soup.find("pre")
    preamble_text = cleanhtml(str(preamble_text))
    # Split the judgment body into sentences.
    judgment_sentences = re.split(r' *[\.\?!][\'"\)\]]* *', judgment_text)
    finalentities = []
    finaltext = ""
    # --- Preamble entities ---
    labellist, text_tokenized = evaluate_one_preamble(model_preamble, preamble_text)
    tokenlist = tokenizer.convert_ids_to_tokens(text_tokenized["input_ids"][0])
    # Drop [CLS] and stop at [SEP].
    finallist = []
    for i in range(1, len(tokenlist)):
        if tokenlist[i] == '[SEP]':
            break
        finallist.append(tokenlist[i])
    finalstring = ""
    i = 0
    finallistshortened = []
    labellistshortened = []
    # Merge WordPiece sub-tokens ("##...") back into whole words, keeping the
    # label predicted for the first sub-token of each word.
    while i < len(finallist):
        word = ""
        word += finallist[i]
        j = i + 1
        labellistshortened.append(labellist[i])
        while j < len(finallist) and finallist[j].startswith("##"):
            word += finallist[j][2:]
            j += 1
        finalstring += word
        finallistshortened.append(word)
        finalstring += " "
        i = j
    text = ""
    entities = []
    i = 0
    # Group consecutive B-/I- labels into entity spans with character offsets.
    while i < len(finallistshortened):
        word = ""
        start = len(text)
        word += finallistshortened[i] + " "
        j = i + 1
        if labellistshortened[i] == "O":
            i += 1
            text += word + " "
            continue
        entity = labellistshortened[i][2:]
        ientity = "I-" + entity
        while j < len(finallistshortened) and labellistshortened[j] == ientity:
            word += finallistshortened[j] + " "
            j += 1
        text += word + " "
        prevstart = len(finaltext)
        end = len(text) - 2
        finalstring += text + ". "
        entities.append((entity, start, end))
        finalentities.append((entity, prevstart + start, prevstart + end))
        i = j
    finaltext += text + ". "
    # --- Judgment body entities, processed sentence by sentence ---
    for sentence in judgment_sentences:
        labellist, text_tokenized = evaluate_one_text(model_judgment, sentence)
        tokenlist = tokenizer.convert_ids_to_tokens(text_tokenized["input_ids"][0])
        finallist = []
        for i in range(1, len(tokenlist)):
            if tokenlist[i] == '[SEP]':
                break
            finallist.append(tokenlist[i])
        finalstring = ""
        i = 0
        finallistshortened = []
        labellistshortened = []
        # if len(labellist) != len(finallist):
        #     print("Unmatching sentence: " + sentence)
        #     if len(labellist) < len(finallist):
        #         while len(labellist) < len(finallist):
        #             labellist.append("O")
        #     if len(labellist) > len(finallist):
        #         while len(labellist) > len(finallist):
        #             finallist.append("O")
        # Merge WordPiece sub-tokens back into whole words (same as for the preamble).
        while i < len(finallist):
            word = ""
            word += finallist[i]
            j = i + 1
            labellistshortened.append(labellist[i])
            while j < len(finallist) and finallist[j].startswith("##"):
                word += finallist[j][2:]
                j += 1
            finalstring += word
            finallistshortened.append(word)
            finalstring += " "
            i = j
        text = ""
        entities = []
        i = 0
        # Group consecutive B-/I- labels into entity spans with character offsets.
        while i < len(finallistshortened):
            word = ""
            start = len(text)
            word += finallistshortened[i] + " "
            j = i + 1
            if labellistshortened[i] == "O":
                i += 1
                text += word + " "
                continue
            entity = labellistshortened[i][2:]
            ientity = "I-" + entity
            while j < len(finallistshortened) and labellistshortened[j] == ientity:
                word += finallistshortened[j] + " "
                j += 1
            text += word + " "
            prevstart = len(finaltext)
            end = len(text) - 2
            finalstring += text + ". "
            entities.append((entity, start, end))
            finalentities.append((entity, prevstart + start, prevstart + end))
            i = j
        finaltext += text + ". "
    # Build a spaCy Doc and attach the predicted entity spans.
    doc = nlp(finaltext)
    ents = []
    for ee in finalentities:
        ents.append(doc.char_span(ee[1], ee[2], ee[0]))
    doc.ents = ents
    # Logic for the repository: persist the extracted entities to MongoDB.
    cluster = MongoClient("mongodb+srv://testuser:[email protected]/?retryWrites=true&w=majority")
    db = cluster["nerdb"]
    collection = db["named_entities"]
    # Scrape the rendered displacy markup to pair each entity's text with its label.
    content = displacy.render(doc, style='ent')
    extsoup = bs(content, 'html.parser')
    txtlist = []
    entlist = []
    for h in extsoup.findAll('div'):
        mark = h.findAll('mark')
        span = h.findAll('span')
        for i in mark:
            txt = i.find(text=True)
            txt = txt.replace("\n", "")
            txt = txt.strip()
            txtlist.append(str(txt))
        for i in span:
            ent = i.find(text=True)
            entlist.append(str(ent))
    zipped = zip(txtlist, entlist)
    ziplist = list(zipped)
    post = {"api_docid": docid, "document_text": finaltext, "named_entities": ziplist, "entities_indices": finalentities}
    collection.insert_one(post)
    # Render the annotated document as a standalone HTML page for the Gradio output.
    html = displacy.render(doc, style="ent", page=True)
    html = (
        ""
        + html
        + ""
    )
    pos_count = {
        "char_count": len(text),
        "token_count": 0,
    }
    pos_tokens = []
    for token in doc:
        pos_tokens.extend([(token.text, token.pos_), (" ", None)])
    return html
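# Gradio interface: takes an Indian Kanoon document URL and returns the displacy-rendered entities as HTML.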
demo = gr.Interface(
    judgtext_analysis,
    gr.Textbox(placeholder="Enter Indian Kanoon document URL here..."),
    ["html"],
)
demo.launch(inline=False)