File size: 5,961 Bytes
a228fac
 
 
 
0dd8e27
a228fac
 
12af45e
a228fac
 
 
 
0d655c9
1be0846
 
a228fac
 
 
 
 
 
 
613ad82
a228fac
0dd8e27
a228fac
 
1be0846
a228fac
 
 
 
 
 
7b11b8d
 
 
 
a228fac
 
0d655c9
12af45e
 
 
 
 
 
 
 
a228fac
 
9d6cf42
 
12af45e
 
 
 
 
 
 
 
 
 
 
 
 
 
0d655c9
 
12af45e
 
 
 
 
1be0846
 
 
 
 
 
12af45e
1be0846
 
 
 
 
 
 
 
 
613ad82
1be0846
12af45e
1be0846
 
 
12af45e
 
1be0846
12af45e
a228fac
 
12af45e
a228fac
0dd8e27
a228fac
 
 
0dd8e27
a228fac
0dd8e27
a228fac
0dd8e27
 
 
 
 
 
 
 
 
 
 
a228fac
0dd8e27
a228fac
 
 
 
0dd8e27
 
 
 
 
 
 
a228fac
0dd8e27
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
from config import Settings
from preprocess import Preprocessor
import ocr
from PIL import Image
from transformers import LiltForTokenClassification, AutoTokenizer
import token_classification
import torch
from fastapi import FastAPI, UploadFile, Form, HTTPException
from contextlib import asynccontextmanager
import json
import io
from models import LiLTRobertaLikeForRelationExtraction
from base64 import b64decode 
from handwritting_detection import DetectHandwritting
import pandas as pd
# Process-wide registry shared by the request handlers: `lifespan` stores the
# settings, torch device, preprocessor, tokenizer and models here at startup
# and empties it again at shutdown.
config = dict()

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load settings and models into ``config`` for the lifetime of the app.

    Runs once at startup, before any request is served, and clears ``config``
    when the application shuts down.

    Args:
        app: The FastAPI application (required by the lifespan protocol; not
            used here).

    Yields:
        Nothing; control returns to FastAPI while the app is serving.
    """
    try:
        settings = Settings()
        config['settings'] = settings
        config['device'] = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        config['processor'] = Preprocessor(settings.TOKENIZER)
        config['tokenizer'] = AutoTokenizer.from_pretrained(settings.TOKENIZER)
        config['ser_model'] = LiltForTokenClassification.from_pretrained(settings.SER_MODEL)

        re_model = LiLTRobertaLikeForRelationExtraction.from_pretrained(settings.RE_MODEL)
        # The relation-extraction model is always run on `device`; move it
        # there once at startup instead of on the first request.
        re_model.to(config['device'])
        config['re_model'] = re_model

        # Nothing in this module reads this key (ApplyOCR reads the URL from
        # config['settings']); kept so existing consumers keep working.
        config['TROCR_API'] = settings.TROCR_API_URL
        yield
    finally:
        # Release the models/resources even if startup failed part-way or the
        # app exits because of an exception/cancellation.
        config.clear()

# ASGI application. `lifespan` fills the module-level `config` dict with the
# settings, device and models at startup and clears it at shutdown.
app = FastAPI(lifespan=lifespan)

@app.get("/")
def api_home():
    """Return a fixed welcome payload for the service root."""
    greeting = {'detail': 'Welcome to Sri-Doc space'}
    return greeting

def _analyze_document(content):
  """Run OCR, token labelling and relation extraction on encoded image bytes.

  Shared pipeline for both document endpoints.

  NOTE(review): ApplyOCR and the model calls are synchronous, so the calling
  async handler blocks the event loop for the whole request.

  Args:
    content: The encoded image file as bytes.

  Returns:
    The dict produced by ``ExtractRelations``.

  Raises:
    HTTPException: 400 if OCR fails (raised by ``ApplyOCR``), if OCR yields
      fewer than two rows, or if token labelling / relation extraction raises.
  """
  ocr_df, image = ApplyOCR(content)
  if len(ocr_df) < 2:
    raise HTTPException(status_code=400, detail="Cannot apply OCR to the image")
  try:
    tokenClassificationOutput, img_size = LabelTokens(ocr_df, image)
    return ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
  except Exception as err:
    # Any model-side failure is reported to the client as a bad image; the
    # cause is chained so it is not lost from tracebacks.
    raise HTTPException(status_code=400, detail="Invalid Image") from err


@app.post("/submit-doc")
async def ProcessDocument(file: UploadFile):
  """Analyze an uploaded image file (multipart upload).

  Args:
    file: The uploaded image.

  Returns:
    The relation-extraction result produced by ``ExtractRelations``.

  Raises:
    HTTPException: 400 if any stage of the pipeline fails.
  """
  content = await file.read()
  return _analyze_document(content)


# NOTE: this redefines the module-level name `ProcessDocument`. The
# /submit-doc route above stays registered because its decorator already ran.
@app.post("/submit-doc-base64")
async def ProcessDocument(file: str = Form(...)):
  """Analyze an image sent as a base64 form field.

  Accepts either a data URL (``data:image/png;base64,<payload>``) or a bare
  base64 payload.

  Args:
    file: The base64-encoded image, optionally prefixed by a data-URL header.

  Returns:
    The relation-extraction result produced by ``ExtractRelations``.

  Raises:
    HTTPException: 400 if the payload cannot be decoded or any pipeline stage
      fails.
  """
  # rpartition returns ('', '', file) when there is no comma, so a bare
  # base64 string is decoded as-is instead of being rejected.
  _header, _sep, payload = file.rpartition(',')
  try:
    content = b64decode(payload.encode())
  except ValueError as err:  # binascii.Error and UnicodeEncodeError subclass ValueError
    raise HTTPException(status_code=400, detail="Invalid image") from err
  return _analyze_document(content)

def ApplyOCR(content):
  """Run handwritten and printed OCR on an encoded image.

  ``DetectHandwritting`` returns an image for printed-text OCR plus the
  handwritten parts (presumably crops — confirm in that module). The
  handwritten parts go to the TrOCR client; the printed image is JPEG-encoded
  and sent to the Vision client.

  NOTE(review): failures of the remote OCR services are reported as 400 even
  though they are not client errors.

  Args:
    content: The encoded image file as bytes.

  Returns:
    A tuple ``(ocr_df, image)``: the handwritten OCR rows followed by the
    printed OCR rows, re-indexed 0..n-1, and the opened PIL image.

  Raises:
    HTTPException: 400 with a stage-specific detail message if any step fails.
  """
  # Encoder modes Pillow's JPEG writer accepts; anything else (RGBA, P, LA, ...)
  # must be converted before saving.
  jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")

  try:
    image = Image.open(io.BytesIO(content))
  except Exception as err:
    raise HTTPException(status_code=400, detail="Invalid image") from err

  try:
    printed_img, handwritten_imgs = DetectHandwritting(image)
  except Exception as err:
    raise HTTPException(status_code=400, detail="Handwritten OCR failed") from err

  try:
    trocr_client = ocr.TrOCRClientClient(config['settings'].TROCR_API_URL)
    handwritten_ocr_df = trocr_client.ocr(handwritten_imgs, image)
  except Exception as err:
    raise HTTPException(status_code=400, detail="handwritten OCR process failed") from err

  try:
    # Saving e.g. an RGBA/palette PNG as JPEG raises; convert only for the
    # encoding step. The unconverted image is still passed to the client.
    jpeg_source = printed_img if printed_img.mode in jpeg_modes else printed_img.convert('RGB')
    jpeg_bytes = io.BytesIO()
    jpeg_source.save(jpeg_bytes, format='JPEG')
    jpeg_content = jpeg_bytes.getvalue()
    vision_client = ocr.VisionClient(config['settings'].GCV_AUTH)
    printed_ocr_df = vision_client.ocr(jpeg_content, printed_img)
  except Exception as err:
    raise HTTPException(status_code=400, detail="Printed OCR process failed") from err

  # ignore_index avoids duplicate index labels from the two independently
  # indexed frames.
  ocr_df = pd.concat([handwritten_ocr_df, printed_ocr_df], ignore_index=True)
  return ocr_df, image


def LabelTokens(ocr_df, image):
  """Preprocess the OCR rows and classify every token with the SER model.

  Args:
    ocr_df: OCR rows as returned by ``ApplyOCR``.
    image: The image the OCR rows were extracted from.

  Returns:
    A tuple ``(features, image.size)``. ``features`` holds the token labels
    together with the ``input_ids``, ``bbox`` and ``attention_mask`` produced
    by the preprocessor, as consumed by ``ExtractRelations``.
  """
  # The preprocessor returns six values; token_type_ids and the actual boxes
  # are not needed downstream.
  (input_ids, attention_mask, _token_type_ids,
   bbox, _token_actual_boxes, offset_mapping) = config['processor'].process(ocr_df, image = image)
  ser_model = config['ser_model']
  token_labels = token_classification.classifyTokens(
    ser_model, input_ids, attention_mask, bbox, offset_mapping)
  features = {
    "token_labels": token_labels,
    "input_ids": input_ids,
    "bbox": bbox,
    "attention_mask": attention_mask,
  }
  return features, image.size

def ExtractRelations(tokenClassificationOutput, ocr_df, img_size):
  """Merge labelled tokens into entities and predict relations between them.

  Args:
    tokenClassificationOutput: The dict returned by ``LabelTokens``.
    ocr_df: OCR rows as returned by ``ApplyOCR``.
    img_size: Size of the source image (``image.size``).

  Returns:
    A dict whose values are JSON strings: the raw predicted relations, the
    entities (labels converted to integer ids), the merged input ids, the
    original bboxes, the token labels, the decoded ``(label, text)`` entities
    and the decoded ``(head_text, tail_text)`` relation pairs.

  Raises:
    KeyError: if an entity carries a label other than HEADER/QUESTION/ANSWER.
  """
  tokenizer = config['tokenizer']
  device = config['device']

  # The attention mask is read for parity with the input contract but is
  # rebuilt from the merged output below.
  token_labels, seq_input_ids, _seq_attention_mask, bbox_org = (
    tokenClassificationOutput[key]
    for key in ("token_labels", "input_ids", "attention_mask", "bbox"))

  merged_output, _merged_words = token_classification.createEntities(
    config['ser_model'], token_labels, seq_input_ids, ocr_df, tokenizer, img_size, bbox_org)

  entities = merged_output['entities']

  def as_batch(key):
    # Wrap a single sequence as a batch of one on the target device.
    return torch.tensor([merged_output[key]]).to(device)

  input_ids = as_batch('input_ids')
  bbox = as_batch('bbox')
  attention_mask = as_batch('attention_mask')
  tokens = input_ids[0]

  # Maps entity label names to the integer ids the RE model expects.
  label2id = {"HEADER": 0, "QUESTION": 1, "ANSWER": 2}
  decoded_entities = []
  for ent in entities:
    # Decode with the string label first, then convert the label in place.
    decoded_entities.append((ent['label'], tokenizer.decode(tokens[ent['start']:ent['end']])))
    ent['label'] = label2id[ent['label']]

  re_model = config['re_model']
  re_model.to(device)
  entity_dict = {field: [ent[field] for ent in entities] for field in ('start', 'end', 'label')}
  # Empty relation placeholder; presumably only used by the model for
  # training targets — confirm against the model implementation.
  no_relations = [{'start_index': [], 'end_index': [], 'head': [], 'tail': []}]
  with torch.no_grad():
    outputs = re_model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask,
                       entities=[entity_dict], relations=no_relations)

  pred_relations = outputs.pred_relations[0]
  decoded_pred_relations = []
  for rel in pred_relations:
    (h_start, h_end), (t_start, t_end) = rel['head'], rel['tail']
    question = tokenizer.decode(tokens[h_start:h_end])
    answer = tokenizer.decode(tokens[t_start:t_end])
    decoded_pred_relations.append((question, answer))

  return {
    "pred_relations": json.dumps(pred_relations),
    "entities": json.dumps(entities),
    "input_ids": json.dumps(input_ids.tolist()),
    "bboxes": json.dumps(bbox_org.tolist()),
    "token_labels": json.dumps(token_labels),
    "decoded_entities": json.dumps(decoded_entities),
    "decoded_pred_relations": json.dumps(decoded_pred_relations),
  }