from config import Settings
from preprocess import Preprocessor
import ocr
from PIL import Image
from transformers import LiltForTokenClassification, AutoTokenizer
import token_classification
import torch
from fastapi import FastAPI, UploadFile, Form
from contextlib import asynccontextmanager
import json
import io
from models import LiLTRobertaLikeForRelationExtraction
from base64 import b64decode

# Shared application state, populated by the lifespan handler below.
config = {}

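# Startup: pick a device, build the Cloud Vision OCR client and preprocessor, and load the
# SER and RE models once; the shared state is cleared again on shutdown.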
@asynccontextmanager
async def lifespan(app: FastAPI):
    settings = Settings()
    config['settings'] = settings
    config['device'] = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config['vision_client'] = ocr.VisionClient(settings.GCV_AUTH)
    config['processor'] = Preprocessor(settings.TOKENIZER)
    config['tokenizer'] = AutoTokenizer.from_pretrained(settings.TOKENIZER)
    config['ser_model'] = LiltForTokenClassification.from_pretrained(settings.SER_MODEL)
    config['re_model'] = LiLTRobertaLikeForRelationExtraction.from_pretrained(settings.RE_MODEL)
    yield
    # Clean up and release the resources
    config.clear()

app = FastAPI(lifespan=lifespan)

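# Accept a document image as a multipart file upload and return token labels plus extracted relations.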
@app.post("/submit-doc")
async def ProcessDocument(file: UploadFile):
  content = await file.read()
  tokenClassificationOutput, ocr_df, img_size = LabelTokens(content)
  reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
  return reOutput

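# Accept the same document as a base64-encoded data URI submitted in a form field.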
@app.post("/submit-doc-base64")
async def ProcessDocumentBase64(file: str = Form(...)):
  # Strip the data-URI header (e.g. "data:image/png;base64,") before decoding the payload
  _, encoded = file.split(',', 1)
  content = b64decode(encoded)
  tokenClassificationOutput, ocr_df, img_size = LabelTokens(content)
  reOutput = ExtractRelations(tokenClassificationOutput, ocr_df, img_size)
  return reOutput

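# Run OCR on the image, convert the words and boxes into model inputs, and label every token with the SER model.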
def LabelTokens(content):
  image = Image.open(io.BytesIO(content))
  ocr_df = config['vision_client'].ocr(content, image)
  input_ids, attention_mask, token_type_ids, bbox, token_actual_boxes, offset_mapping = config['processor'].process(ocr_df, image=image)
  token_labels = token_classification.classifyTokens(config['ser_model'], input_ids, attention_mask, bbox, offset_mapping)
  return {"token_labels": token_labels, "input_ids": input_ids, "bbox": bbox, "attention_mask": attention_mask}, ocr_df, image.size

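# Merge token-level predictions into entities and link them with the relation-extraction model.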
def ExtractRelations(tokenClassificationOutput, ocr_df, img_size):
  token_labels = tokenClassificationOutput['token_labels']
  input_ids = tokenClassificationOutput['input_ids']
  attention_mask = tokenClassificationOutput["attention_mask"]
  bbox_org = tokenClassificationOutput["bbox"]

  merged_output, merged_words = token_classification.createEntities(config['ser_model'], token_labels, input_ids, ocr_df, config['tokenizer'], img_size, bbox_org)
  
  entities = merged_output['entities']
  input_ids = torch.tensor([merged_output['input_ids']]).to(config['device'])
  bbox = torch.tensor([merged_output['bbox']]).to(config['device'])
  attention_mask = torch.tensor([merged_output['attention_mask']]).to(config['device'])

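  # The RE model expects integer label ids; keep a human-readable copy of each entity for the response.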
  label2id = {"HEADER": 0, "QUESTION": 1, "ANSWER": 2}
  decoded_entities = []
  for entity in entities:
    decoded_entities.append((entity['label'], config['tokenizer'].decode(input_ids[0][entity['start']:entity['end']])))
    entity['label'] = label2id[entity['label']]

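  # Run relation extraction; `relations` is an empty placeholder because only predictions are needed at inference time.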
  config['re_model'].to(config['device'])
  entity_dict = {'start': [entity['start'] for entity in entities], 'end': [entity['end'] for entity in entities], 'label': [entity['label'] for entity in entities]}
  relations = [{'start_index': [], 'end_index': [], 'head': [], 'tail': []}]
  with torch.no_grad():
    outputs = config['re_model'](input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, entities=[entity_dict], relations=relations)

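  # Decode each predicted head/tail span back into (question, answer) text pairs.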
  decoded_pred_relations = []
  for relation in outputs.pred_relations[0]:
    head_start, head_end = relation['head']
    tail_start, tail_end = relation['tail']
    question = config['tokenizer'].decode(input_ids[0][head_start:head_end])
    answer = config['tokenizer'].decode(input_ids[0][tail_start:tail_end])
    decoded_pred_relations.append((question, answer))

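  # JSON-encode tensors and predictions so the response body contains only plain strings.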
  return {
    "pred_relations": json.dumps(outputs.pred_relations[0]),
    "entities": json.dumps(entities),
    "input_ids": json.dumps(input_ids.tolist()),
    "bboxes": json.dumps(bbox_org.tolist()),
    "token_labels": json.dumps(token_labels),
    "decoded_entities": json.dumps(decoded_entities),
    "decoded_pred_relations": json.dumps(decoded_pred_relations),
  }