import gradio as gr
import http.client
import json
from bs4 import BeautifulSoup as bs
import re
import pymongo
import torch
import spacy
from spacy import displacy
# from pymongo import MongoClient
from transformers import AutoTokenizer, BertForTokenClassification

# InLegalBERT tokenizer shared by the preamble and judgment taggers
tokenizer = AutoTokenizer.from_pretrained("law-ai/InLegalBERT")


class BertModel(torch.nn.Module):
    """Thin wrapper around BertForTokenClassification fine-tuned for legal NER."""

    def __init__(self):
        super(BertModel, self).__init__()
        self.bert = BertForTokenClassification.from_pretrained('law-ai/InLegalBERT', num_labels=14)

    def forward(self, input_id, mask, label):
        output = self.bert(input_ids=input_id, attention_mask=mask, labels=label, return_dict=False)
        return output


# The checkpoints hold full pickled models, so torch.load replaces the freshly
# constructed BertModel instances.
model_preamble = BertModel()
model_preamble = torch.load("nerbert_preamble.pt", map_location=torch.device('cpu'))
model_judgment = BertModel()
model_judgment = torch.load("nerbert.pt", map_location=torch.device('cpu'))

# Entity label sets (BIO scheme) for the two taggers
unique_labels_preamble = {'I-PETITIONER', 'I-COURT', 'B-COURT', 'B-JUDGE', 'I-LAWYER', 'B-RESPONDENT',
                          'I-JUDGE', 'B-PETITIONER', 'I-RESPONDENT', 'B-LAWYER', 'O'}
unique_labels_judgment = {'B-WITNESS', 'I-PETITIONER', 'I-JUDGE', 'B-STATUTE', 'B-OTHER_PERSON',
                          'B-CASE_NUMBER', 'I-ORG', 'I-PRECEDENT', 'I-RESPONDENT', 'B-PROVISION', 'O',
                          'I-WITNESS', 'B-ORG', 'I-COURT', 'B-RESPONDENT', 'I-DATE', 'B-GPE',
                          'I-CASE_NUMBER', 'B-DATE', 'B-PRECEDENT', 'I-GPE', 'B-COURT', 'B-JUDGE',
                          'I-STATUTE', 'B-PETITIONER', 'I-OTHER_PERSON', 'I-PROVISION'}

labels_to_ids_preamble = {k: v for v, k in enumerate(sorted(unique_labels_preamble))}
ids_to_labels_preamble = {v: k for v, k in enumerate(sorted(unique_labels_preamble))}
labels_to_ids_judgment = {k: v for v, k in enumerate(sorted(unique_labels_judgment))}
ids_to_labels_judgment = {v: k for v, k in enumerate(sorted(unique_labels_judgment))}

label_all_tokens = True


def align_word_ids(texts):
    # Mark which word-piece positions carry a label (1) and which are
    # special/padding tokens to ignore (-100), mirroring the training setup.
    tokenized_inputs = tokenizer(texts, padding='max_length', max_length=512, truncation=True)
    word_ids = tokenized_inputs.word_ids()
    previous_word_idx = None
    label_ids = []
    for word_idx in word_ids:
        if word_idx is None:
            label_ids.append(-100)
        elif word_idx != previous_word_idx:
            label_ids.append(1)
        else:
            label_ids.append(1 if label_all_tokens else -100)
        previous_word_idx = word_idx
    return label_ids


def evaluate_one_preamble(model, sentence):
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    model = model.to(device)  # keep the model on the same device as the inputs
    text = tokenizer(sentence, padding='max_length', max_length=512, truncation=True, return_tensors="pt")
    mask = text['attention_mask'].to(device)
    input_id = text['input_ids'].to(device)
    label_ids = torch.Tensor(align_word_ids(sentence)).unsqueeze(0).to(device)
    logits = model(input_id, mask, None)
    # Keep only the positions that correspond to real word pieces
    logits_clean = logits[0][label_ids != -100]
    predictions = logits_clean.argmax(dim=1).tolist()
    prediction_label = [ids_to_labels_preamble[i] for i in predictions]
    return (prediction_label, text)


def evaluate_one_text(model, sentence):
    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    model = model.to(device)  # keep the model on the same device as the inputs
    text = tokenizer(sentence, padding='max_length', max_length=512, truncation=True, return_tensors="pt")
    mask = text['attention_mask'].to(device)
    input_id = text['input_ids'].to(device)
    label_ids = torch.Tensor(align_word_ids(sentence)).unsqueeze(0).to(device)
    logits = model(input_id, mask, None)
    # Keep only the positions that correspond to real word pieces
    logits_clean = logits[0][label_ids != -100]
    predictions = logits_clean.argmax(dim=1).tolist()
    prediction_label = [ids_to_labels_judgment[i] for i in predictions]
    return (prediction_label, text)
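
# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original app flow): shows how the
# preamble tagger above could be exercised on its own. The sample string and
# the helper name `_demo_preamble_tagging` are assumptions added only for
# illustration; the token/label alignment mirrors the loop in
# judgtext_analysis below (skip [CLS], stop at [SEP]). The function is defined
# but never called.
# ---------------------------------------------------------------------------
def _demo_preamble_tagging():
    # Hypothetical preamble text; any court preamble string would do.
    sample = "IN THE HIGH COURT OF DELHI AT NEW DELHI CIVIL APPEAL NO. 1234 OF 2020"
    labels, encoded = evaluate_one_preamble(model_preamble, sample)
    tokens = tokenizer.convert_ids_to_tokens(encoded["input_ids"][0])
    word_pieces = []
    for tok in tokens[1:]:      # skip [CLS]
        if tok == '[SEP]':      # stop at the end of the real input
            break
        word_pieces.append(tok)
    # Each retained word piece lines up with one predicted BIO label
    for tok, lab in zip(word_pieces, labels):
        print(tok, lab)
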
def cleanhtml(raw_html):
    # Strip any remaining HTML tags from a string
    CLEANR = re.compile('<.*?>')
    cleantext = re.sub(CLEANR, '', raw_html)
    return cleantext


nlp = spacy.blank("en")


def judgtext_analysis(text):
    # Fetch the judgment document from the Indian Kanoon API
    conn = http.client.HTTPSConnection("api.indiankanoon.org")
    payload = "{}"
    headers = {
        'Authorization': 'Token ea381f5b51f9d55aaa71dfe6a90606e9b89f942a',
        'Content-Type': 'application/json'
    }
    # Parse the URL and retrieve the document id
    d = text.split('/')
    docid = d[4]
    endpoint = "/doc/" + str(docid) + "/"
    conn.request("POST", endpoint, payload, headers)
    res = conn.getresponse()
    data = res.read()
    data = data.decode("utf-8")
    data_dict = json.loads(data)

    # The judgment body lives in <p>/<blockquote> tags; the preamble is in <pre>
    soup = bs(data_dict["doc"], 'html.parser')
    judgment_text = ""
    for tag in soup.find_all(['p', 'blockquote']):
        judgment_text += (tag.text) + " "
    judgment_text = cleanhtml(str(judgment_text))
    preamble_text = soup.find("pre")
    preamble_text = cleanhtml(str(preamble_text))

    # Split the judgment body into sentences on terminal punctuation
    judgment_sentences = sentences = re.split(r' *[\.\?!][\'"\)\]]* *', judgment_text)

    finalentities = []
    finaltext = ""

    # Tag the preamble and map the predictions back to word-piece tokens
    labellist, text_tokenized = evaluate_one_preamble(model_preamble, preamble_text)
    tokenlist = tokenizer.convert_ids_to_tokens(text_tokenized["input_ids"][0])
    finallist = []
    for i in range(1, len(tokenlist)):
        if tokenlist[i] == '[SEP]':
            break
        finallist.append(tokenlist[i])

    finalstring = ""
    i = 0
    finallistshortened = []
    labellistshortened = []
    while(i