import gradio as gr
import http.client
import json
from bs4 import BeautifulSoup as bs
import re
import pymongo
import torch
import spacy
from spacy import displacy
# from pymongo import MongoClient
from transformers import AutoTokenizer, BertForTokenClassification

tokenizer = AutoTokenizer.from_pretrained("law-ai/InLegalBERT")

class BertModel(torch.nn.Module):
    def __init__(self):
        super(BertModel, self).__init__()
        self.bert = BertForTokenClassification.from_pretrained('law-ai/InLegalBERT', num_labels=14)

    def forward(self, input_id, mask, label):
        output = self.bert(input_ids=input_id, attention_mask=mask, labels=label, return_dict=False)
        return output


model_preamble = BertModel()
model_preamble = torch.load("nerbert_preamble.pt", map_location=torch.device('cpu'))
model_judgment = BertModel()
model_judgment = torch.load("nerbert.pt", map_location=torch.device('cpu'))
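# Note (assumption): the two .pt files are taken to be full pickled BertModel
# objects, so torch.load() returns ready-to-use models and the BertModel()
# placeholders above are simply overwritten. If the checkpoints instead held
# state dicts saved via torch.save(model.state_dict(), ...), a sketch of the
# equivalent loading would be:
#   model_judgment = BertModel()
#   model_judgment.load_state_dict(torch.load("nerbert.pt", map_location="cpu"))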

unique_labels_preamble = {'I-PETITIONER', 'I-COURT', 'B-COURT', 'B-JUDGE', 'I-LAWYER', 'B-RESPONDENT', 'I-JUDGE', 'B-PETITIONER', 'I-RESPONDENT', 'B-LAWYER', 'O'}
unique_labels_judgment = {'B-WITNESS', 'I-PETITIONER', 'I-JUDGE', 'B-STATUTE', 'B-OTHER_PERSON', 'B-CASE_NUMBER', 'I-ORG', 'I-PRECEDENT', 'I-RESPONDENT', 'B-PROVISION', 'O', 'I-WITNESS', 'B-ORG', 'I-COURT', 'B-RESPONDENT', 'I-DATE', 'B-GPE', 'I-CASE_NUMBER', 'B-DATE', 'B-PRECEDENT', 'I-GPE', 'B-COURT', 'B-JUDGE', 'I-STATUTE', 'B-PETITIONER', 'I-OTHER_PERSON', 'I-PROVISION'}

labels_to_ids_preamble = {k: v for v, k in enumerate(sorted(unique_labels_preamble))}
ids_to_labels_preamble = {v: k for v, k in enumerate(sorted(unique_labels_preamble))}
labels_to_ids_judgment = {k: v for v, k in enumerate(sorted(unique_labels_judgment))}
ids_to_labels_judgment = {v: k for v, k in enumerate(sorted(unique_labels_judgment))}

label_all_tokens = True


def align_word_ids(texts):
    # Mark which token positions carry a label: -100 for special/padding tokens,
    # 1 for word-piece positions (every sub-token when label_all_tokens is True).
    tokenized_inputs = tokenizer(texts, padding='max_length', max_length=512, truncation=True)
    word_ids = tokenized_inputs.word_ids()
    previous_word_idx = None
    label_ids = []
    for word_idx in word_ids:
        if word_idx is None:
            label_ids.append(-100)
        elif word_idx != previous_word_idx:
            label_ids.append(1)
        else:
            label_ids.append(1 if label_all_tokens else -100)
        previous_word_idx = word_idx
    return label_ids
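# Illustration (assuming a short two-word input with no sub-word splits):
# align_word_ids("the petitioner") returns roughly
#   [-100, 1, 1, -100, -100, ..., -100]   # length 512 after padding
# i.e. -100 for [CLS]/[SEP]/[PAD] positions and 1 for real word pieces, so the
# boolean mask (label_ids != -100) later keeps exactly the word-piece logits.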


def evaluate_one_preamble(model, sentence):
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    model = model.to(device)
    model.eval()
    text = tokenizer(sentence, padding='max_length', max_length=512, truncation=True, return_tensors="pt")
    mask = text['attention_mask'].to(device)
    input_id = text['input_ids'].to(device)
    label_ids = torch.Tensor(align_word_ids(sentence)).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(input_id, mask, None)
    # Keep only the positions that correspond to real word pieces, then decode labels.
    logits_clean = logits[0][label_ids != -100]
    predictions = logits_clean.argmax(dim=1).tolist()
    prediction_label = [ids_to_labels_preamble[i] for i in predictions]
    return (prediction_label, text)


def evaluate_one_text(model, sentence):
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    model = model.to(device)
    model.eval()
    text = tokenizer(sentence, padding='max_length', max_length=512, truncation=True, return_tensors="pt")
    mask = text['attention_mask'].to(device)
    input_id = text['input_ids'].to(device)
    label_ids = torch.Tensor(align_word_ids(sentence)).unsqueeze(0).to(device)
    with torch.no_grad():
        logits = model(input_id, mask, None)
    logits_clean = logits[0][label_ids != -100]
    predictions = logits_clean.argmax(dim=1).tolist()
    prediction_label = [ids_to_labels_judgment[i] for i in predictions]
    return (prediction_label, text)


def cleanhtml(raw_html):
    # Remove any HTML tags that survive the BeautifulSoup text extraction.
    CLEANR = re.compile('<.*?>')
    cleantext = re.sub(CLEANR, '', raw_html)
    return cleantext
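# Illustration (not part of the pipeline): cleanhtml("<p>Section 302 <b>IPC</b></p>")
# returns "Section 302 IPC".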


nlp = spacy.blank("en")


def judgtext_analysis(text):
    # Fetch the judgment from the Indian Kanoon API.
    conn = http.client.HTTPSConnection("api.indiankanoon.org")
    payload = "{}"
    headers = {
        'Authorization': 'Token ea381f5b51f9d55aaa71dfe6a90606e9b89f942a',
        'Content-Type': 'application/json'
    }
    # Parse the URL and retrieve the document id
    d = text.split('/')
    docid = d[4]
    endpoint = "/doc/" + str(docid) + "/"
    conn.request("POST", endpoint, payload, headers)
    res = conn.getresponse()
    data = res.read()
    data = data.decode("utf-8")
    data_dict = json.loads(data)
    soup = bs(data_dict["doc"], 'html.parser')

    # The judgment body lives in <p>/<blockquote> tags; the preamble in a <pre> tag.
    judgment_text = ""
    for tag in soup.find_all(['p', 'blockquote']):
        judgment_text += tag.text + " "
    judgment_text = cleanhtml(str(judgment_text))
    preamble_text = soup.find("pre")
    preamble_text = cleanhtml(str(preamble_text))
    # Split the judgment into sentences at ., ? or ! (optionally followed by closing quotes/brackets).
    judgment_sentences = re.split(r' *[\.\?!][\'"\)\]]* *', judgment_text)

    finalentities = []
    finaltext = ""

    # --- Preamble: predict token labels, re-join WordPiece sub-tokens, collect entity spans ---
    labellist, text_tokenized = evaluate_one_preamble(model_preamble, preamble_text)
    tokenlist = tokenizer.convert_ids_to_tokens(text_tokenized["input_ids"][0])
    finallist = []
    for i in range(1, len(tokenlist)):
        if tokenlist[i] == '[SEP]':
            break
        finallist.append(tokenlist[i])
    # Merge "##" continuation pieces back into whole words; keep one label per word.
    finalstring = ""
    i = 0
    finallistshortened = []
    labellistshortened = []
    while i < len(finallist):
        word = ""
        word += finallist[i]
        j = i + 1
        labellistshortened.append(labellist[i])
        while j < len(finallist) and finallist[j].startswith("##"):
            word += finallist[j][2:]
            j += 1
        finalstring += word
        finallistshortened.append(word)
        finalstring += " "
        i = j
    # Walk the word/label lists and record (label, start, end) character offsets.
    text = ""
    entities = []
    i = 0
    while i < len(finallistshortened):
        word = ""
        start = len(text)
        word += finallistshortened[i] + " "
        j = i + 1
        if labellistshortened[i] == "O":
            i += 1
            text += word + " "
            continue
        entity = labellistshortened[i][2:]
        ientity = "I-" + entity
        while j < len(finallistshortened) and labellistshortened[j] == ientity:
            word += finallistshortened[j] + " "
            j += 1
        text += word + " "
        prevstart = len(finaltext)
        end = len(text) - 2
        finalstring += text + ". "
        entities.append((entity, start, end))
        finalentities.append((entity, prevstart + start, prevstart + end))
        i = j
    finaltext += text + ". "

    # --- Judgment body: repeat the same procedure sentence by sentence ---
    for sentence in judgment_sentences:
        labellist, text_tokenized = evaluate_one_text(model_judgment, sentence)
        tokenlist = tokenizer.convert_ids_to_tokens(text_tokenized["input_ids"][0])
        finallist = []
        for i in range(1, len(tokenlist)):
            if tokenlist[i] == '[SEP]':
                break
            finallist.append(tokenlist[i])
        finalstring = ""
        i = 0
        finallistshortened = []
        labellistshortened = []
        while i < len(finallist):
            word = ""
            word += finallist[i]
            j = i + 1
            labellistshortened.append(labellist[i])
            while j < len(finallist) and finallist[j].startswith("##"):
                word += finallist[j][2:]
                j += 1
            finalstring += word
            finallistshortened.append(word)
            finalstring += " "
            i = j
        text = ""
        entities = []
        i = 0
        while i < len(finallistshortened):
            word = ""
            start = len(text)
            word += finallistshortened[i] + " "
            j = i + 1
            if labellistshortened[i] == "O":
                i += 1
                text += word + " "
                continue
            entity = labellistshortened[i][2:]
            ientity = "I-" + entity
            while j < len(finallistshortened) and labellistshortened[j] == ientity:
                word += finallistshortened[j] + " "
                j += 1
            text += word + " "
            prevstart = len(finaltext)
            end = len(text) - 2
            finalstring += text + ". "
            entities.append((entity, start, end))
            finalentities.append((entity, prevstart + start, prevstart + end))
            i = j
        finaltext += text + ". "

    # Build a spaCy Doc over the reconstructed text and attach the predicted spans.
    # alignment_mode="expand" snaps character offsets to token boundaries; spans that
    # still cannot be aligned are dropped, and filter_spans removes any overlaps.
    doc = nlp(finaltext)
    ents = []
    for ee in finalentities:
        span = doc.char_span(ee[1], ee[2], label=ee[0], alignment_mode="expand")
        if span is not None:
            ents.append(span)
    doc.ents = spacy.util.filter_spans(ents)

    # logic for repository
    # cluster = MongoClient("mongodb+srv://testuser:[email protected]/?retryWrites=true&w=majority")
    # db = cluster["nerdb"]
    # collection = db["named_entities"]
    content = displacy.render(doc, style='ent')
    # extsoup = bs(content, 'html.parser')
    # txtlist = []
    # entlist = []
    # for h in extsoup.findAll('div'):
    #     mark = h.findAll('mark')
    #     span = h.findAll('span')
    #     for i in mark:
    #         txt = i.find(text=True)
    #         txt = txt.replace("\n", "")
    #         txt = txt.strip()
    #         txtlist.append(str(txt))
    #     for i in span:
    #         ent = i.find(text=True)
    #         entlist.append(str(ent))
    # zipped = zip(txtlist, entlist)
    # ziplist = list(zipped)
    # post = {"api_docid": docid, "document_text": finaltext, "named_entities": ziplist, "entities_indices": finalentities}
    # collection.insert_one(post)
    html = displacy.render(doc, style="ent", page=True)
    return html


demo = gr.Interface(
    judgtext_analysis,
    gr.Textbox(placeholder="Enter Indian Kanoon document URL here..."),
    ["html"],
)
demo.launch(inline=False)
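# Usage note (illustrative; the doc id below is a placeholder): paste an Indian
# Kanoon judgment URL of the form https://indiankanoon.org/doc/1234567/ into the
# textbox. The app fetches the document via the Indian Kanoon API, runs the two
# InLegalBERT NER models over the preamble and judgment text, and returns the
# displaCy entity visualisation as HTML.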