import io
import json
from contextlib import asynccontextmanager

import torch
from fastapi import FastAPI, UploadFile
from PIL import Image
from transformers import LiltForTokenClassification

import ocr
import token_classification
from config import Settings
from models import LiLTRobertaLikeForRelationExtraction
from preprocess import Preprocessor
# Shared resources (settings, device, OCR client, preprocessor, models),
# populated at startup by `lifespan` and cleared on shutdown.
config = {}

@asynccontextmanager
async def lifespan(app: FastAPI):
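    """Load settings, the OCR client, the preprocessor, and both models once per process."""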
    settings = Settings()
    config['settings'] = settings
    config['device'] = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config['vision_client'] = ocr.VisionClient(settings.GCV_AUTH)
    config['processor'] = Preprocessor(settings.TOKENIZER)
    config['ser_model'] = LiltForTokenClassification.from_pretrained(settings.SER_MODEL)
    config['re_model'] = LiLTRobertaLikeForRelationExtraction.from_pretrained(settings.RE_MODEL)
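    # Models load on CPU by default; the RE model is moved to `config['device']`
    # at request time in extract_relations.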
    yield
    # Clean up and release the resources
    config.clear()

app = FastAPI(lifespan=lifespan)
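
# Example usage (a sketch; the module name `app` and the port are assumptions):
#   uvicorn app:app --host 0.0.0.0 --port 8000
#   curl -X POST -F "file=@document.png" http://localhost:8000/submit-doc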

@app.post("/submit-doc")
async def process_document(file: UploadFile):
    """OCR the uploaded document, label tokens (SER), then extract relations."""
    token_classification_output = await label_tokens(file)
    re_output = extract_relations(token_classification_output)
    return re_output

async def label_tokens(file):
    """Run OCR on the uploaded image and classify tokens with the SER model."""
    content = await file.read()
    image = Image.open(io.BytesIO(content))
    ocr_df = config['vision_client'].ocr(content, image)
    # token_type_ids and token_actual_boxes are not needed downstream.
    input_ids, attention_mask, token_type_ids, bbox, token_actual_boxes, offset_mapping = config['processor'].process(ocr_df, image=image)
    token_labels = token_classification.classifyTokens(config['ser_model'], input_ids, attention_mask, bbox, offset_mapping)
    return {"token_labels": token_labels, "input_ids": input_ids, "bbox": bbox,
            "offset_mapping": offset_mapping, "attention_mask": attention_mask}

def extract_relations(token_classification_output):
    """Build entities from the SER output and run relation extraction."""
    token_labels = token_classification_output['token_labels']
    input_ids = token_classification_output['input_ids']
    offset_mapping = token_classification_output['offset_mapping']
    attention_mask = token_classification_output['attention_mask']
    bbox = token_classification_output['bbox']

    entities = token_classification.createEntities(config['ser_model'], token_labels, input_ids, offset_mapping)

    device = config['device']
    re_model = config['re_model'].to(device)
    # Entity spans and labels in the format the RE model expects; index 3 holds the label.
    entity_dict = {
        'start': [entity[0] for entity in entities],
        'end': [entity[1] for entity in entities],
        'label': [entity[3] for entity in entities],
    }
    # One empty relation dict per batch element: predicted relations are decoded by the model.
    relations = [{'start_index': [], 'end_index': [], 'head': [], 'tail': []}]
    with torch.no_grad():
        outputs = re_model(
            input_ids=input_ids.to(device),  # inputs must be on the same device as the model
            bbox=bbox.to(device),
            attention_mask=attention_mask.to(device),
            entities=[entity_dict],
            relations=relations,
        )

    # Each value is pre-serialized with json.dumps, so clients receive
    # JSON-encoded strings inside the response object.
    return {
        "pred_relations": json.dumps(outputs.pred_relations[0]),
        "entities": json.dumps(entities),
        "input_ids": json.dumps(input_ids.tolist()),
        "bboxes": json.dumps(bbox.tolist()),
        "token_labels": json.dumps(token_labels),
    }