sri-doc / main.py
kavg's picture
Added a home route
7b11b8d
raw
history blame
5.24 kB
from config import Settings
from preprocess import Preprocessor
import ocr
from PIL import Image
from transformers import LiltForTokenClassification, AutoTokenizer
import token_classification
import torch
from fastapi import FastAPI, UploadFile, Form, HTTPException
from contextlib import asynccontextmanager
import json
import io
from models import LiLTRobertaLikeForRelationExtraction
from base64 import b64decode
# Process-wide registry of shared resources. It is filled in by `lifespan`
# at startup and read by the request handlers and pipeline helpers.
config = {}


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load shared resources at startup and release them at shutdown.

    Populates the module-level ``config`` dict with:

    * ``settings``      -- a ``Settings()`` instance
    * ``device``        -- CUDA device when available, otherwise CPU
    * ``vision_client`` -- ``ocr.VisionClient`` built from ``settings.GCV_AUTH``
    * ``processor``     -- ``Preprocessor`` built from ``settings.TOKENIZER``
    * ``tokenizer``     -- tokenizer loaded from ``settings.TOKENIZER``
    * ``ser_model``     -- token-classification model from ``settings.SER_MODEL``
    * ``re_model``      -- relation-extraction model from ``settings.RE_MODEL``

    Args:
        app: The FastAPI application this lifespan is attached to (unused).
    """
    try:
        settings = Settings()
        config['settings'] = settings
        config['device'] = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        config['vision_client'] = ocr.VisionClient(settings.GCV_AUTH)
        config['processor'] = Preprocessor(settings.TOKENIZER)
        config['tokenizer'] = AutoTokenizer.from_pretrained(settings.TOKENIZER)
        config['ser_model'] = LiltForTokenClassification.from_pretrained(settings.SER_MODEL)
        config['re_model'] = LiLTRobertaLikeForRelationExtraction.from_pretrained(settings.RE_MODEL)
        yield
    finally:
        # Clean up and release the resources. Running this in `finally`
        # drops the references even when startup fails part-way through or
        # an exception is thrown into the generator at the `yield`.
        config.clear()


app = FastAPI(lifespan=lifespan)
@app.get("/")
def api_home():
    """Serve the static greeting returned at the root path."""
    greeting = dict(detail='Welcome to Sri-Doc space')
    return greeting
@app.post("/submit-doc")
async def ProcessDocument(file: UploadFile):
    """Run OCR, token labelling and relation extraction on an uploaded image.

    Args:
        file: The uploaded document image.

    Returns:
        The dict produced by ``ExtractRelations``.

    Raises:
        HTTPException: 400 when ``ApplyOCR`` rejects the image, when the OCR
            result has fewer than two rows, or when token labelling or
            relation extraction fails.
    """
    # NOTE: the base64 handler defined later in this file reuses this
    # function name. The route is registered when the decorator runs, so
    # this endpoint keeps working even though the module attribute is
    # rebound afterwards.
    content = await file.read()
    ocr_df, image = ApplyOCR(content)
    if len(ocr_df) < 2:
        raise HTTPException(status_code=400, detail="Cannot apply OCR to the image")
    try:
        tokenClassificationOutput, img_size = LabelTokens(ocr_df, image)
        reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
    except Exception as err:
        # `Exception` rather than a bare `except:` so that task cancellation
        # and interpreter-exit signals are not rewritten into a 400 response.
        raise HTTPException(status_code=400, detail="Invalid Image") from err
    return reOutput
@app.post("/submit-doc-base64")
async def ProcessDocument(file: str = Form(...)):
    """Run the document pipeline on a base64-encoded image sent as a form field.

    Args:
        file: Either ``"<header>,<base64 payload>"`` (for example a data URI)
            or the bare base64 payload. Everything up to and including the
            last comma is treated as the header and discarded.

    Returns:
        The dict produced by ``ExtractRelations``.

    Raises:
        HTTPException: 400 when the payload cannot be base64-decoded, when
            ``ApplyOCR`` rejects the image, when the OCR result has fewer
            than two rows, or when token labelling or relation extraction
            fails.
    """
    # rpartition never fails to unpack: with no comma the whole string is the
    # payload, and with several commas only the final segment is decoded.
    _header, _sep, payload = file.rpartition(',')
    try:
        content = b64decode(payload.encode())
    except ValueError as err:
        # binascii.Error (bad padding) and UnicodeEncodeError are both
        # ValueError subclasses.
        raise HTTPException(status_code=400, detail="Invalid image") from err
    ocr_df, image = ApplyOCR(content)
    if len(ocr_df) < 2:
        raise HTTPException(status_code=400, detail="Cannot apply OCR to the image")
    try:
        tokenClassificationOutput, img_size = LabelTokens(ocr_df, image)
        reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
    except Exception as err:
        # `Exception` rather than a bare `except:` so that task cancellation
        # and interpreter-exit signals are not rewritten into a 400 response.
        raise HTTPException(status_code=400, detail="Invalid Image") from err
    return reOutput
def ApplyOCR(content):
    """Open the raw bytes as an image and run the configured OCR client on it.

    Args:
        content: Raw bytes of the image file.

    Returns:
        A ``(ocr_df, image)`` tuple: the result of
        ``config['vision_client'].ocr(content, image)`` and the opened PIL
        image.

    Raises:
        HTTPException: 400 with detail ``"Invalid image"`` when the bytes
            cannot be opened as an image, or 400 with detail
            ``"OCR process failed"`` when the OCR call raises.
    """
    try:
        image = Image.open(io.BytesIO(content))
    except Exception as err:
        # Narrowed from a bare `except:` so BaseException subclasses such as
        # KeyboardInterrupt are not reported as a client error.
        raise HTTPException(status_code=400, detail="Invalid image") from err
    try:
        ocr_df = config['vision_client'].ocr(content, image)
    except Exception as err:
        raise HTTPException(status_code=400, detail="OCR process failed") from err
    return ocr_df, image
def LabelTokens(ocr_df, image):
    """Encode the OCR output and classify every token with the SER model.

    Args:
        ocr_df: OCR result handed to ``config['processor'].process``.
        image: The opened image; passed to the processor and used for its
            ``size``.

    Returns:
        A ``(ser_output, image_size)`` tuple. ``ser_output`` is a dict with
        the keys ``token_labels``, ``input_ids``, ``bbox`` and
        ``attention_mask``; ``image_size`` is ``image.size``.
    """
    encoded = config['processor'].process(ocr_df, image = image)
    # The processor returns six values; the token type ids and the actual
    # token boxes are not needed by this step.
    (
        input_ids,
        attention_mask,
        _token_type_ids,
        bbox,
        _token_actual_boxes,
        offset_mapping,
    ) = encoded

    predicted_labels = token_classification.classifyTokens(
        config['ser_model'],
        input_ids,
        attention_mask,
        bbox,
        offset_mapping,
    )

    ser_output = {
        "token_labels": predicted_labels,
        "input_ids": input_ids,
        "bbox": bbox,
        "attention_mask": attention_mask,
    }
    return ser_output, image.size
def ExtractRelations(tokenClassificationOutput, ocr_df, img_size):
    """Merge labelled tokens into entities and predict relations between them.

    Args:
        tokenClassificationOutput: Dict with the keys ``token_labels``,
            ``input_ids``, ``attention_mask`` and ``bbox`` (as returned by
            ``LabelTokens``).
        ocr_df: OCR result forwarded to ``token_classification.createEntities``.
        img_size: Image size forwarded to ``createEntities``.

    Returns:
        A dict whose values are all JSON strings, under the keys
        ``pred_relations``, ``entities``, ``input_ids``, ``bboxes``,
        ``token_labels``, ``decoded_entities`` and ``decoded_pred_relations``.

    Raises:
        KeyError: If an entity label is not one of ``HEADER``, ``QUESTION``
            or ``ANSWER``.
    """
    tokenizer = config['tokenizer']
    device = config['device']

    token_labels = tokenClassificationOutput['token_labels']
    ser_input_ids = tokenClassificationOutput['input_ids']
    # Looked up to keep the input contract; the mask actually fed to the
    # model is the one rebuilt from the merged output below.
    _ser_attention_mask = tokenClassificationOutput["attention_mask"]
    bbox_org = tokenClassificationOutput["bbox"]

    merged_output, _merged_words = token_classification.createEntities(
        config['ser_model'],
        token_labels,
        ser_input_ids,
        ocr_df,
        tokenizer,
        img_size,
        bbox_org,
    )
    entities = merged_output['entities']

    def _as_batch(key):
        # Wrap the merged sequence in a batch of one and move it to the device.
        return torch.tensor([merged_output[key]]).to(device)

    input_ids = _as_batch('input_ids')
    bbox = _as_batch('bbox')
    attention_mask = _as_batch('attention_mask')

    label_to_id = {"HEADER":0, "QUESTION":1, "ANSWER":2}
    decoded_entities = []
    for ent in entities:
        # Record the textual label before it is replaced in place by its id.
        label_name = ent['label']
        span_text = tokenizer.decode(input_ids[0][ent['start']:ent['end']])
        decoded_entities.append((label_name, span_text))
        ent['label'] = label_to_id[ent['label']]

    re_model = config['re_model']
    re_model.to(device)

    entity_dict = {
        field: [ent[field] for ent in entities]
        for field in ('start', 'end', 'label')
    }
    # No gold relations are supplied; a single empty record is passed.
    relations = [{'start_index': [], 'end_index': [], 'head': [], 'tail': []}]

    with torch.no_grad():
        outputs = re_model(
            input_ids=input_ids,
            bbox=bbox,
            attention_mask=attention_mask,
            entities=[entity_dict],
            relations=relations,
        )

    predicted = outputs.pred_relations[0]
    decoded_pred_relations = []
    for rel in predicted:
        (head_start, head_end), (tail_start, tail_end) = rel['head'], rel['tail']
        question = tokenizer.decode(input_ids[0][head_start:head_end])
        answer = tokenizer.decode(input_ids[0][tail_start:tail_end])
        decoded_pred_relations.append((question, answer))

    payload = {
        "pred_relations": predicted,
        "entities": entities,
        "input_ids": input_ids.tolist(),
        "bboxes": bbox_org.tolist(),
        "token_labels": token_labels,
        "decoded_entities": decoded_entities,
        "decoded_pred_relations": decoded_pred_relations,
    }
    return {key: json.dumps(value) for key, value in payload.items()}