File size: 4,404 Bytes
a228fac 0dd8e27 a228fac 0d655c9 a228fac 0d655c9 a228fac 0dd8e27 a228fac 0d655c9 0dd8e27 a228fac 9d6cf42 7c6755e 9d6cf42 0d655c9 a228fac 7c6755e a228fac 0dd8e27 a228fac 0dd8e27 a228fac 0dd8e27 a228fac 0dd8e27 a228fac 0dd8e27 a228fac 0dd8e27 a228fac 0dd8e27 a228fac 0dd8e27 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 |
from config import Settings
from preprocess import Preprocessor
import ocr
from PIL import Image
from transformers import LiltForTokenClassification, AutoTokenizer
import token_classification
import torch
from fastapi import FastAPI, UploadFile, Form
from contextlib import asynccontextmanager
import json
import io
from models import LiLTRobertaLikeForRelationExtraction
from base64 import b64decode
# Shared application state: filled in by `lifespan` at startup, read by the
# request helpers, and emptied again at shutdown.
config = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Populate ``config`` for the lifetime of the application.

    On startup this stores, under the keys shown:

    * ``settings``      - a fresh ``Settings()`` instance
    * ``device``        - CUDA when ``torch.cuda.is_available()``, else CPU
    * ``vision_client`` - ``ocr.VisionClient`` built from ``settings.GCV_AUTH``
    * ``processor``     - ``Preprocessor`` built from ``settings.TOKENIZER``
    * ``tokenizer``     - ``AutoTokenizer`` loaded from ``settings.TOKENIZER``
    * ``ser_model``     - token-classification model from ``settings.SER_MODEL``
    * ``re_model``      - relation-extraction model from ``settings.RE_MODEL``

    Args:
        app: The FastAPI application being started (not used directly).

    Yields:
        Nothing; control returns to the server while the app is running.
    """
    try:
        settings = Settings()
        config['settings'] = settings
        config['device'] = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        config['vision_client'] = ocr.VisionClient(settings.GCV_AUTH)
        config['processor'] = Preprocessor(settings.TOKENIZER)
        config['tokenizer'] = AutoTokenizer.from_pretrained(settings.TOKENIZER)
        config['ser_model'] = LiltForTokenClassification.from_pretrained(settings.SER_MODEL)
        config['re_model'] = LiLTRobertaLikeForRelationExtraction.from_pretrained(settings.RE_MODEL)
        yield
    finally:
        # Clean up and release the resources. Running this in `finally` means
        # the references are dropped even when startup fails half-way or an
        # exception is thrown into the generator at the `yield`.
        config.clear()
# ASGI application; `lifespan` populates `config` before requests are served
# and clears it on shutdown.
app = FastAPI(lifespan=lifespan)
@app.post("/submit-doc")
async def ProcessDocument(file: UploadFile):
    """Handle a multipart upload: label its tokens, then link the entities.

    Args:
        file: The uploaded document image.

    Returns:
        Whatever ``ExtractRelations`` produces for the uploaded bytes.
    """
    raw_bytes = await file.read()
    ser_output, ocr_frame, image_size = LabelTokens(raw_bytes)
    return ExtractRelations(ser_output, ocr_frame, image_size)
@app.post("/submit-doc-base64")
async def ProcessDocument(file: str = Form(...)):
    """Handle a document submitted as a base64 string in a form field.

    The field may be a bare base64 payload or be preceded by a header that
    ends in a comma (e.g. ``data:image/png;base64,<payload>``).

    Args:
        file: Form field holding the (optionally prefixed) base64 text.

    Returns:
        Whatever ``ExtractRelations`` produces for the decoded bytes.

    Raises:
        HTTPException: 400 when the payload cannot be base64-decoded.
    """
    # Imported here so the handler does not depend on a change to the
    # module-level import list.
    from fastapi import HTTPException

    # The base64 alphabet contains no comma, so anything up to the last comma
    # can only be header. `rpartition` also copes with input that has no
    # header at all, where unpacking `split(',')` used to raise ValueError.
    payload = file.rpartition(',')[2]
    try:
        content = b64decode(payload.encode())
    except ValueError as err:  # binascii.Error is a ValueError subclass
        raise HTTPException(
            status_code=400, detail="file is not valid base64 data"
        ) from err

    tokenClassificationOutput, ocr_df, img_size = LabelTokens(content)
    reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
    return reOutput
def LabelTokens(content):
    """Run OCR and token classification on an encoded image.

    Args:
        content: Raw bytes of an image file that PIL can open.

    Returns:
        A 3-tuple ``(outputs, ocr_df, size)``:

        * ``outputs`` - dict with the keys ``"token_labels"``, ``"input_ids"``,
          ``"bbox"`` and ``"attention_mask"``
        * ``ocr_df``  - the result of the vision client's ``ocr`` call
          (presumably a DataFrame, going by the name - confirm in ``ocr``)
        * ``size``    - ``image.size`` of the decoded image
    """
    image = Image.open(io.BytesIO(content))
    # NOTE: the former `image.show()` call was removed. It is a debugging aid
    # that launches an external viewer, which must not happen per request.
    ocr_df = config['vision_client'].ocr(content, image)
    # Only four of the six values returned by the preprocessor are needed here.
    (
        input_ids,
        attention_mask,
        _token_type_ids,
        bbox,
        _token_actual_boxes,
        offset_mapping,
    ) = config['processor'].process(ocr_df, image=image)
    token_labels = token_classification.classifyTokens(
        config['ser_model'], input_ids, attention_mask, bbox, offset_mapping
    )
    outputs = {
        "token_labels": token_labels,
        "input_ids": input_ids,
        "bbox": bbox,
        "attention_mask": attention_mask,
    }
    return outputs, ocr_df, image.size
def ExtractRelations(tokenClassificationOutput, ocr_df, img_size):
    """Merge labelled tokens into entities and predict relations between them.

    Args:
        tokenClassificationOutput: Dict with the keys ``"token_labels"``,
            ``"input_ids"``, ``"attention_mask"`` and ``"bbox"``, as built by
            ``LabelTokens``.
        ocr_df: OCR result forwarded to ``token_classification.createEntities``.
        img_size: Image size forwarded to ``createEntities``.

    Returns:
        A dict whose values are all JSON-encoded strings, under the keys
        ``pred_relations``, ``entities``, ``input_ids``, ``bboxes``,
        ``token_labels``, ``decoded_entities`` and ``decoded_pred_relations``.
    """
    tokenizer = config['tokenizer']
    device = config['device']

    token_labels = tokenClassificationOutput['token_labels']
    ser_input_ids = tokenClassificationOutput['input_ids']
    # Looked up as before, but superseded by the merged mask built further down.
    _ser_attention_mask = tokenClassificationOutput["attention_mask"]
    bbox_org = tokenClassificationOutput["bbox"]

    merged_output, _merged_words = token_classification.createEntities(
        config['ser_model'],
        token_labels,
        ser_input_ids,
        ocr_df,
        tokenizer,
        img_size,
        bbox_org,
    )
    entities = merged_output['entities']

    def as_batch(key):
        # Wrap the merged sequence in a batch of one and move it to the device.
        return torch.tensor([merged_output[key]]).to(device)

    input_ids = as_batch('input_ids')
    bbox = as_batch('bbox')
    attention_mask = as_batch('attention_mask')

    def decode_span(start, end):
        # Turn a token span of the merged sequence back into text.
        return tokenizer.decode(input_ids[0][start:end])

    label_to_id = {"HEADER": 0, "QUESTION": 1, "ANSWER": 2}
    decoded_entities = []
    for entity in entities:
        span_text = decode_span(entity['start'], entity['end'])
        decoded_entities.append((entity['label'], span_text))
        # The string label is replaced in place by its integer id; the
        # relation model input and the returned `entities` both use the id.
        entity['label'] = label_to_id[entity['label']]

    re_model = config['re_model']
    re_model.to(device)

    entity_dict = {
        field: [entity[field] for entity in entities]
        for field in ('start', 'end', 'label')
    }
    relations = [{'start_index': [], 'end_index': [], 'head': [], 'tail': []}]

    with torch.no_grad():
        outputs = re_model(
            input_ids=input_ids,
            bbox=bbox,
            attention_mask=attention_mask,
            entities=[entity_dict],
            relations=relations,
        )

    predicted = outputs.pred_relations[0]
    # Each relation's head span is decoded as the question, its tail as the answer.
    decoded_pred_relations = [
        (decode_span(*relation['head']), decode_span(*relation['tail']))
        for relation in predicted
    ]

    return {
        "pred_relations": json.dumps(predicted),
        "entities": json.dumps(entities),
        "input_ids": json.dumps(input_ids.tolist()),
        "bboxes": json.dumps(bbox_org.tolist()),
        "token_labels": json.dumps(token_labels),
        "decoded_entities": json.dumps(decoded_entities),
        "decoded_pred_relations": json.dumps(decoded_pred_relations),
    }