# (removed: non-source extraction artifacts — file-size header, git-blame hashes,
#  and a line-number strip that preceded the actual code)
from config import Settings
from preprocess import Preprocessor
import ocr
from PIL import Image
from transformers import LiltForTokenClassification, AutoTokenizer
import token_classification
import torch
from fastapi import FastAPI, UploadFile, Form, HTTPException
from contextlib import asynccontextmanager
import json
import io
from models import LiLTRobertaLikeForRelationExtraction
from base64 import b64decode
from handwritting_detection import DetectHandwritting
import pandas as pd
config = {}
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Populate the shared ``config`` dict at startup; clear it at shutdown.

    Loads the preprocessor, tokenizer, SER model and RE model declared in
    ``Settings``, and picks CUDA when available, CPU otherwise.
    """
    settings = Settings()
    chosen_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config.update(
        settings=settings,
        device=chosen_device,
        processor=Preprocessor(settings.TOKENIZER),
        tokenizer=AutoTokenizer.from_pretrained(settings.TOKENIZER),
        ser_model=LiltForTokenClassification.from_pretrained(settings.SER_MODEL),
        re_model=LiLTRobertaLikeForRelationExtraction.from_pretrained(settings.RE_MODEL),
        TROCR_API=settings.TROCR_API_URL,
    )
    yield
    # Shutdown: drop all references so models/resources can be released.
    config.clear()
app = FastAPI(lifespan=lifespan)
@app.get("/")
def api_home():
    """Landing endpoint — confirms the service is up."""
    welcome = {'detail': 'Welcome to Sri-Doc space'}
    return welcome
@app.post("/submit-doc")
async def ProcessDocument(file: UploadFile):
    """Run the full pipeline (OCR -> token labeling -> relation extraction)
    on an uploaded document image.

    Raises HTTPException(400) when OCR yields too little text or any pipeline
    stage fails.
    """
    content = await file.read()
    ocr_df, image = ApplyOCR(content)
    # Fewer than 2 OCR rows means there is nothing meaningful to label.
    if len(ocr_df) < 2:
        raise HTTPException(status_code=400, detail="Cannot apply OCR to the image")
    try:
        tokenClassificationOutput, img_size = LabelTokens(ocr_df, image)
        reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
    except HTTPException:
        # Let deliberate HTTP errors pass through unchanged.
        raise
    except Exception as exc:
        # Was a bare `except:` (also caught SystemExit/KeyboardInterrupt);
        # narrow it and chain the cause for debuggability.
        raise HTTPException(status_code=400, detail="Invalid Image") from exc
    return reOutput
@app.post("/submit-doc-base64")
async def ProcessDocumentBase64(file: str = Form(...)):
    """Same pipeline as /submit-doc, but the image arrives as a base64 data URL
    form field (``data:...;base64,<payload>``).

    Renamed from ``ProcessDocument``: the original redefined the upload
    handler's name at module level (F811 shadowing). Routes are registered at
    decorator time, so the paths are unaffected.
    """
    try:
        # Split off the "data:<mime>;base64," prefix, keep the payload.
        head, payload = file.split(',')
        content = b64decode(payload.encode())
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid image") from exc
    ocr_df, image = ApplyOCR(content)
    # Fewer than 2 OCR rows means there is nothing meaningful to label.
    if len(ocr_df) < 2:
        raise HTTPException(status_code=400, detail="Cannot apply OCR to the image")
    try:
        tokenClassificationOutput, img_size = LabelTokens(ocr_df, image)
        reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
    except HTTPException:
        raise
    except Exception as exc:
        # Narrowed from a bare `except:`; chain the cause.
        raise HTTPException(status_code=400, detail="Invalid Image") from exc
    return reOutput
def ApplyOCR(content):
    """Decode image bytes, split printed vs. handwritten regions, and OCR both.

    Handwritten regions go to the TrOCR service; the printed remainder goes to
    Google Cloud Vision. Returns ``(ocr_df, image)`` where ``ocr_df`` is the
    concatenation of both OCR results and ``image`` is the original PIL image.

    Raises HTTPException(400) with a stage-specific message on any failure.
    All handlers were bare ``except:`` clauses; they are narrowed to
    ``Exception`` and chained so the root cause is preserved.
    """
    try:
        image = Image.open(io.BytesIO(content))
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Invalid image") from exc
    try:
        printed_img, handwritten_imgs = DetectHandwritting(image)
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Handwritten OCR failed") from exc
    try:
        # NOTE(review): "TrOCRClientClient" looks like a typo for "TrOCRClient"
        # — confirm against the ocr module before changing.
        trocr_client = ocr.TrOCRClientClient(config['settings'].TROCR_API_URL)
        handwritten_ocr_df = trocr_client.ocr(handwritten_imgs, image)
    except Exception as exc:
        raise HTTPException(status_code=400, detail="handwritten OCR process failed") from exc
    try:
        # Vision expects encoded bytes, so re-encode the printed region as JPEG.
        jpeg_bytes = io.BytesIO()
        printed_img.save(jpeg_bytes, format='JPEG')
        jpeg_content = jpeg_bytes.getvalue()
        vision_client = ocr.VisionClient(config['settings'].GCV_AUTH)
        printed_ocr_df = vision_client.ocr(jpeg_content, printed_img)
    except Exception as exc:
        raise HTTPException(status_code=400, detail="Printed OCR process failed") from exc
    ocr_df = pd.concat([handwritten_ocr_df, printed_ocr_df])
    return ocr_df, image
def LabelTokens(ocr_df, image):
    """Preprocess the OCR dataframe and classify each token with the SER model.

    Returns a dict with the token labels plus the tensors needed downstream
    (input_ids, bbox, attention_mask), and the image's (width, height).
    """
    processed = config['processor'].process(ocr_df, image=image)
    input_ids, attention_mask, token_type_ids, bbox, token_actual_boxes, offset_mapping = processed
    token_labels = token_classification.classifyTokens(
        config['ser_model'], input_ids, attention_mask, bbox, offset_mapping
    )
    result = {
        "token_labels": token_labels,
        "input_ids": input_ids,
        "bbox": bbox,
        "attention_mask": attention_mask,
    }
    return result, image.size
def ExtractRelations(tokenClassificationOutput, ocr_df, img_size):
    """Merge labeled tokens into entities and run the relation-extraction model.

    Returns a dict of JSON-encoded strings: predicted relations, entities,
    input ids, bounding boxes, token labels, and human-readable decodings of
    entities and (question, answer) relation pairs.
    """
    token_labels = tokenClassificationOutput['token_labels']
    input_ids = tokenClassificationOutput['input_ids']
    attention_mask = tokenClassificationOutput["attention_mask"]
    bbox_org = tokenClassificationOutput["bbox"]
    merged_output, merged_words = token_classification.createEntities(
        config['ser_model'], token_labels, input_ids, ocr_df,
        config['tokenizer'], img_size, bbox_org
    )
    entities = merged_output['entities']
    input_ids = torch.tensor([merged_output['input_ids']]).to(config['device'])
    bbox = torch.tensor([merged_output['bbox']]).to(config['device'])
    attention_mask = torch.tensor([merged_output['attention_mask']]).to(config['device'])
    # Renamed from `id2label`: this maps label name -> id, i.e. it is label2id.
    label2id = {"HEADER": 0, "QUESTION": 1, "ANSWER": 2}
    decoded_entities = []
    for entity in entities:
        # Decode the entity text BEFORE replacing its string label with an id.
        decoded_entities.append(
            (entity['label'], config['tokenizer'].decode(input_ids[0][entity['start']:entity['end']]))
        )
        entity['label'] = label2id[entity['label']]
    config['re_model'].to(config['device'])
    entity_dict = {
        'start': [entity['start'] for entity in entities],
        'end': [entity['end'] for entity in entities],
        'label': [entity['label'] for entity in entities],
    }
    # Empty relations scaffold: the RE model fills in predictions.
    relations = [{'start_index': [], 'end_index': [], 'head': [], 'tail': []}]
    with torch.no_grad():
        outputs = config['re_model'](
            input_ids=input_ids, bbox=bbox, attention_mask=attention_mask,
            entities=[entity_dict], relations=relations
        )
    decoded_pred_relations = []
    for relation in outputs.pred_relations[0]:
        head_start, head_end = relation['head']
        tail_start, tail_end = relation['tail']
        question = config['tokenizer'].decode(input_ids[0][head_start:head_end])
        answer = config['tokenizer'].decode(input_ids[0][tail_start:tail_end])
        decoded_pred_relations.append((question, answer))
    return {
        "pred_relations": json.dumps(outputs.pred_relations[0]),
        "entities": json.dumps(entities),
        "input_ids": json.dumps(input_ids.tolist()),
        "bboxes": json.dumps(bbox_org.tolist()),
        "token_labels": json.dumps(token_labels),
        "decoded_entities": json.dumps(decoded_entities),
        "decoded_pred_relations": json.dumps(decoded_pred_relations),
    }