Commit · a228fac
Parent(s):
Initial commit

Files changed:
- .gitignore +5 -0
- README.md +6 -0
- config.py +9 -0
- download_model.ipynb +144 -0
- main.py +66 -0
- models.py +236 -0
- ocr.py +89 -0
- preprocess.py +111 -0
- requirements.txt +0 -0
- token_classification.py +36 -0
.gitignore
ADDED
@@ -0,0 +1,5 @@
lilt-env/
.env
temp/
__pycache__/
models/
README.md
ADDED
@@ -0,0 +1,6 @@
1. Create a virtualenv
`virtualenv lilt-env`
2. Install packages
`pip install -r requirements.txt`
3. Run the app
`uvicorn main:app --reload`
config.py
ADDED
@@ -0,0 +1,9 @@
from pydantic_settings import BaseSettings, SettingsConfigDict
from pydantic import Field

class Settings(BaseSettings):
    model_config = SettingsConfigDict(env_file='.env', env_file_encoding='utf-8')
    GCV_AUTH: dict
    SER_MODEL: str
    TOKENIZER: str
    RE_MODEL: str
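For reference, a hypothetical .env matching these fields (placeholder values, not part of this commit): pydantic-settings parses GCV_AUTH from a one-line JSON string into a dict, and the model paths below assume the directories written by download_model.ipynb.

# placeholder, abridged service-account JSON; must be valid JSON on a single line
GCV_AUTH={"type": "service_account", "project_id": "my-project", "client_email": "ocr@my-project.iam.gserviceaccount.com", "private_key": "..."}
SER_MODEL=models/lilt-ser-iob
TOKENIZER=models/lilt-tokenizer
RE_MODEL=models/lilt-re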
download_model.ipynb
ADDED
@@ -0,0 +1,144 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "d:\\FYP\\lilt-app-without-fd\\lilt-env\\lib\\site-packages\\tqdm\\auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      " from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from transformers import LiltModel, AutoTokenizer, LiltForTokenClassification"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download tokenizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('models/lilt-tokenizer\\\\tokenizer_config.json',\n",
       " 'models/lilt-tokenizer\\\\special_tokens_map.json',\n",
       " 'models/lilt-tokenizer\\\\tokenizer.json')"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "TOKENIZER = 'nielsr/lilt-xlm-roberta-base'\n",
    "tokenizer = AutoTokenizer.from_pretrained(TOKENIZER)\n",
    "save_dir = 'models/lilt-tokenizer'\n",
    "tokenizer.save_pretrained(save_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download and save token classification model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# download the model\n",
    "MODEL = \"pierreguillou/lilt-xlm-roberta-base-finetuned-funsd-iob-original\"\n",
    "model = LiltForTokenClassification.from_pretrained(MODEL)\n",
    "\n",
    "# save the model\n",
    "save_dir = \"models/lilt-ser-iob\"\n",
    "model.save_pretrained(save_dir)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Download and save RE model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Downloading config.json: 100%|██████████| 794/794 [00:00<00:00, 61.2kB/s]\n",
      "d:\\FYP\\lilt-app-without-fd\\lilt-env\\lib\\site-packages\\huggingface_hub\\file_download.py:133: UserWarning: `huggingface_hub` cache-system uses symlinks by default to efficiently store duplicated files but your machine does not support them in C:\\Users\\Gihantha Kavishka\\.cache\\huggingface\\hub. Caching files will still work but in a degraded version that might require more space on your disk. This warning can be disabled by setting the `HF_HUB_DISABLE_SYMLINKS_WARNING` environment variable. For more details, see https://huggingface.co/docs/huggingface_hub/how-to-cache#limitations.\n",
      "To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development\n",
      " warnings.warn(message)\n",
      "Downloading pytorch_model.bin: 100%|██████████| 1.15G/1.15G [08:10<00:00, 2.34MB/s]\n",
      "Some weights of the model checkpoint at kavg/layoutxlm-finetuned-xfund-fr-re were not used when initializing LiltModel: ['extractor.rel_classifier.linear.weight', 'extractor.entity_emb.weight', 'extractor.ffnn_tail.0.weight', 'extractor.ffnn_tail.3.bias', 'extractor.ffnn_head.3.weight', 'extractor.ffnn_head.0.weight', 'extractor.ffnn_tail.0.bias', 'extractor.ffnn_head.3.bias', 'extractor.rel_classifier.bilinear.weight', 'extractor.rel_classifier.linear.bias', 'extractor.ffnn_head.0.bias', 'extractor.ffnn_tail.3.weight']\n",
      "- This IS expected if you are initializing LiltModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing LiltModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of LiltModel were not initialized from the model checkpoint at kavg/layoutxlm-finetuned-xfund-fr-re and are newly initialized: ['lilt.pooler.dense.bias', 'lilt.pooler.dense.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "# download the model\n",
    "MODEL = 'kavg/layoutxlm-finetuned-xfund-fr-re'\n",
    "model = LiltModel.from_pretrained(MODEL)\n",
    "\n",
    "# save the model\n",
    "save_dir = \"models/lilt-re\"\n",
    "model.save_pretrained(save_dir)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lilt-env",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.8"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
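As a quick sanity check (a sketch, assuming the notebook has been run so the models/ directories exist), the saved artifacts can be reloaded locally. Note that main.py loads the RE checkpoint through the custom LiLTRobertaLikeForRelationExtraction class from models.py, which reinitializes the RE head weights and prints a warning.

from transformers import AutoTokenizer, LiltForTokenClassification
from models import LiLTRobertaLikeForRelationExtraction

# assumes download_model.ipynb has already populated these directories
tokenizer = AutoTokenizer.from_pretrained("models/lilt-tokenizer")
ser_model = LiltForTokenClassification.from_pretrained("models/lilt-ser-iob")
re_model = LiLTRobertaLikeForRelationExtraction.from_pretrained("models/lilt-re")
print(ser_model.config.num_labels, re_model.config.hidden_size)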
main.py
ADDED
@@ -0,0 +1,66 @@
from config import Settings
from preprocess import Preprocessor
import ocr
from PIL import Image
from transformers import LiltForTokenClassification
import token_classification
import torch
from fastapi import FastAPI, UploadFile
from contextlib import asynccontextmanager
import json
import io
from models import LiLTRobertaLikeForRelationExtraction

config = {}

@asynccontextmanager
async def lifespan(app: FastAPI):
    settings = Settings()
    config['settings'] = settings
    config['device'] = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    config['vision_client'] = ocr.VisionClient(settings.GCV_AUTH)
    config['processor'] = Preprocessor(settings.TOKENIZER)
    config['ser_model'] = LiltForTokenClassification.from_pretrained(settings.SER_MODEL)
    config['re_model'] = LiLTRobertaLikeForRelationExtraction.from_pretrained(settings.RE_MODEL)
    yield
    # Clean up and release the resources
    config.clear()

app = FastAPI(lifespan=lifespan)

@app.post("/submit-doc")
async def ProcessDocument(file: UploadFile):
    tokenClassificationOutput = await LabelTokens(file)
    reOutput = ExtractRelations(tokenClassificationOutput)
    return reOutput

async def LabelTokens(file):
    content = await file.read()
    image = Image.open(io.BytesIO(content))
    ocr_df = config['vision_client'].ocr(content, image)
    input_ids, attention_mask, token_type_ids, bbox, token_actual_boxes, offset_mapping = config['processor'].process(ocr_df, image=image)
    token_labels = token_classification.classifyTokens(config['ser_model'], input_ids, attention_mask, bbox, offset_mapping)
    return {"token_labels": token_labels, "input_ids": input_ids, "bbox": bbox, "offset_mapping": offset_mapping, "attention_mask": attention_mask}

def ExtractRelations(tokenClassificationOutput):
    token_labels = tokenClassificationOutput['token_labels']
    input_ids = tokenClassificationOutput['input_ids']
    offset_mapping = tokenClassificationOutput["offset_mapping"]
    attention_mask = tokenClassificationOutput["attention_mask"]
    bbox = tokenClassificationOutput["bbox"]

    entities = token_classification.createEntities(config['ser_model'], token_labels, input_ids, offset_mapping)

    config['re_model'].to(config['device'])
    entity_dict = {'start': [entity[0] for entity in entities], 'end': [entity[1] for entity in entities], 'label': [entity[3] for entity in entities]}
    relations = [{'start_index': [], 'end_index': [], 'head': [], 'tail': []}]
    with torch.no_grad():
        outputs = config['re_model'](input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, entities=[entity_dict], relations=relations)

    print(type(outputs.pred_relations[0]))
    print(type(entities))
    print(type(input_ids))
    print(type(bbox))
    print(type(token_labels))
    # "pred_relations":json.dumps(outputs.pred_relations[0]), "entities":json.dumps(entities), "input_ids": json.dumps(input_ids.tolist()),

    return {"pred_relations": json.dumps(outputs.pred_relations[0]), "entities": json.dumps(entities), "input_ids": json.dumps(input_ids.tolist()), "bboxes": json.dumps(bbox.tolist()), "token_labels": json.dumps(token_labels)}
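The /submit-doc endpoint takes a multipart upload under the field name file. A hypothetical client call (assumes the server is running locally on the default uvicorn port; sample.png is a placeholder path):

import requests

# hypothetical client call; sample.png stands in for a scanned document image
with open("sample.png", "rb") as f:
    resp = requests.post("http://127.0.0.1:8000/submit-doc", files={"file": f})

# the response is a JSON object whose values are themselves JSON-encoded strings
print(resp.json().keys())  # pred_relations, entities, input_ids, bboxes, token_labels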
models.py
ADDED
@@ -0,0 +1,236 @@
from transformers import LiltPreTrainedModel, LiltModel
import copy
import torch
from torch import nn
from torch.nn import CrossEntropyLoss
from dataclasses import dataclass
from typing import Dict, Optional, Tuple
from transformers.utils import ModelOutput

class BiaffineAttention(torch.nn.Module):
    """Implements a biaffine attention operator for binary relation classification.

    PyTorch implementation of the biaffine attention operator from "End-to-end neural relation
    extraction using deep biaffine attention" (https://arxiv.org/abs/1812.11275) which can be used
    as a classifier for binary relation classification.

    Args:
        in_features (int): The size of the feature dimension of the inputs.
        out_features (int): The size of the feature dimension of the output.

    Shape:
        - x_1: `(N, *, in_features)` where `N` is the batch dimension and `*` means any number of
          additional dimensions.
        - x_2: `(N, *, in_features)`, where `N` is the batch dimension and `*` means any number of
          additional dimensions.
        - Output: `(N, *, out_features)`, where `N` is the batch dimension and `*` means any number
          of additional dimensions.

    Examples:
        >>> batch_size, in_features, out_features = 32, 100, 4
        >>> biaffine_attention = BiaffineAttention(in_features, out_features)
        >>> x_1 = torch.randn(batch_size, in_features)
        >>> x_2 = torch.randn(batch_size, in_features)
        >>> output = biaffine_attention(x_1, x_2)
        >>> print(output.size())
        torch.Size([32, 4])
    """

    def __init__(self, in_features, out_features):
        super(BiaffineAttention, self).__init__()

        self.in_features = in_features
        self.out_features = out_features

        self.bilinear = torch.nn.Bilinear(in_features, in_features, out_features, bias=False)
        self.linear = torch.nn.Linear(2 * in_features, out_features, bias=True)

        self.reset_parameters()

    def forward(self, x_1, x_2):
        return self.bilinear(x_1, x_2) + self.linear(torch.cat((x_1, x_2), dim=-1))

    def reset_parameters(self):
        self.bilinear.reset_parameters()
        self.linear.reset_parameters()


class REDecoder(nn.Module):
    def __init__(self, config, input_size):
        super().__init__()
        self.entity_emb = nn.Embedding(3, input_size, scale_grad_by_freq=True)
        projection = nn.Sequential(
            nn.Linear(input_size * 2, config.hidden_size),
            nn.ReLU(),
            nn.Dropout(config.hidden_dropout_prob),
            nn.Linear(config.hidden_size, config.hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(config.hidden_dropout_prob),
        )
        self.ffnn_head = copy.deepcopy(projection)
        self.ffnn_tail = copy.deepcopy(projection)
        self.rel_classifier = BiaffineAttention(config.hidden_size // 2, 2)
        self.loss_fct = CrossEntropyLoss()

    def build_relation(self, relations, entities):
        batch_size = len(relations)
        new_relations = []
        for b in range(batch_size):
            if len(entities[b]["start"]) <= 2:
                entities[b] = {"end": [1, 1], "label": [0, 0], "start": [0, 0]}
            all_possible_relations = set(
                [
                    (i, j)
                    for i in range(len(entities[b]["label"]))
                    for j in range(len(entities[b]["label"]))
                    if entities[b]["label"][i] == 1 and entities[b]["label"][j] == 2
                ]
            )
            if len(all_possible_relations) == 0:
                all_possible_relations = set([(0, 1)])
            positive_relations = set(list(zip(relations[b]["head"], relations[b]["tail"])))
            negative_relations = all_possible_relations - positive_relations
            positive_relations = set([i for i in positive_relations if i in all_possible_relations])
            reordered_relations = list(positive_relations) + list(negative_relations)
            relation_per_doc = {"head": [], "tail": [], "label": []}
            relation_per_doc["head"] = [i[0] for i in reordered_relations]
            relation_per_doc["tail"] = [i[1] for i in reordered_relations]
            relation_per_doc["label"] = [1] * len(positive_relations) + [0] * (
                len(reordered_relations) - len(positive_relations)
            )
            assert len(relation_per_doc["head"]) != 0
            new_relations.append(relation_per_doc)
        return new_relations, entities

    def get_predicted_relations(self, logits, relations, entities):
        pred_relations = []
        for i, pred_label in enumerate(logits.argmax(-1)):
            if pred_label != 1:
                continue
            rel = {}
            rel["head_id"] = relations["head"][i]
            rel["head"] = (entities["start"][rel["head_id"]], entities["end"][rel["head_id"]])
            rel["head_type"] = entities["label"][rel["head_id"]]

            rel["tail_id"] = relations["tail"][i]
            rel["tail"] = (entities["start"][rel["tail_id"]], entities["end"][rel["tail_id"]])
            rel["tail_type"] = entities["label"][rel["tail_id"]]
            rel["type"] = 1
            pred_relations.append(rel)
        return pred_relations

    def forward(self, hidden_states, entities, relations):
        batch_size, max_n_words, context_dim = hidden_states.size()
        device = hidden_states.device
        relations, entities = self.build_relation(relations, entities)
        loss = 0
        all_pred_relations = []
        all_logits = []
        all_labels = []

        for b in range(batch_size):
            head_entities = torch.tensor(relations[b]["head"], device=device)
            tail_entities = torch.tensor(relations[b]["tail"], device=device)
            relation_labels = torch.tensor(relations[b]["label"], device=device)
            entities_start_index = torch.tensor(entities[b]["start"], device=device)
            entities_labels = torch.tensor(entities[b]["label"], device=device)
            head_index = entities_start_index[head_entities]
            head_label = entities_labels[head_entities]
            head_label_repr = self.entity_emb(head_label)

            tail_index = entities_start_index[tail_entities]
            tail_label = entities_labels[tail_entities]
            tail_label_repr = self.entity_emb(tail_label)

            head_repr = torch.cat(
                (hidden_states[b][head_index], head_label_repr),
                dim=-1,
            )
            tail_repr = torch.cat(
                (hidden_states[b][tail_index], tail_label_repr),
                dim=-1,
            )
            heads = self.ffnn_head(head_repr)
            tails = self.ffnn_tail(tail_repr)
            logits = self.rel_classifier(heads, tails)
            pred_relations = self.get_predicted_relations(logits, relations[b], entities[b])
            all_pred_relations.append(pred_relations)
            all_logits.append(logits)
            all_labels.append(relation_labels)
        all_logits = torch.cat(all_logits, 0)
        all_labels = torch.cat(all_labels, 0)
        loss = self.loss_fct(all_logits, all_labels)
        return loss, all_pred_relations


@dataclass
class ReOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None
    entities: Optional[Dict] = None
    relations: Optional[Dict] = None
    pred_relations: Optional[Dict] = None

class REHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.extractor = REDecoder(config, config.hidden_size)

    def forward(self, sequence_output, entities, relations):
        sequence_output = self.dropout(sequence_output)
        loss, pred_relations = self.extractor(sequence_output, entities, relations)
        return ReOutput(
            loss=loss,
            entities=entities,
            relations=relations,
            pred_relations=pred_relations,
        )

class LiLTRobertaLikeForRelationExtraction(LiltPreTrainedModel):
    _keys_to_ignore_on_load_unexpected = [r"pooler"]
    _keys_to_ignore_on_load_missing = [r"position_ids"]

    def __init__(self, config):
        super().__init__(config)

        self.lilt = LiltModel(config, add_pooling_layer=False)
        self.rehead = REHead(config)
        self.init_weights()

    def forward(
        self,
        input_ids=None,
        bbox=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        entities=None,
        relations=None,
    ):

        outputs = self.lilt(
            input_ids,
            bbox=bbox,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        re_output = self.rehead(sequence_output, entities, relations)
        return re_output
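A minimal sketch (dummy tensors, not from this commit) of how the RE head is driven: entity labels follow the 0/1/2 = HEADER/QUESTION/ANSWER convention from token_classification.py, candidate links are every (QUESTION, ANSWER) pair, and LiltConfig defaults are assumed for the hidden size.

import torch
from transformers import LiltConfig
from models import REHead

config = LiltConfig()                                     # assumed defaults, hidden_size = 768
head = REHead(config)
hidden_states = torch.randn(1, 512, config.hidden_size)   # stand-in for the LiltModel sequence output
# one QUESTION (label 1) and two ANSWER (label 2) entities, given as token spans
entities = [{"start": [5, 20, 40], "end": [10, 25, 45], "label": [1, 2, 2]}]
relations = [{"start_index": [], "end_index": [], "head": [], "tail": []}]
out = head(hidden_states, entities, relations)
print(out.pred_relations[0])                               # predicted head/tail links, possibly empty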
ocr.py
ADDED
@@ -0,0 +1,89 @@
from google.cloud import vision
from google.oauth2 import service_account
from google.protobuf.json_format import MessageToJson
import pandas as pd
import json
import numpy as np
from PIL import Image
import io

image_ext = ("*.jpg", "*.jpeg", "*.png")

class VisionClient:
    def __init__(self, auth):
        credentials = service_account.Credentials.from_service_account_info(
            auth
        )
        self.client = vision.ImageAnnotatorClient(credentials=credentials)

    def send_request(self, image):
        try:
            image = vision.Image(content=image)
        except ValueError as e:
            print("Image could not be read")
            return
        response = self.client.document_text_detection(image, timeout=10)
        return response

    def get_response(self, content):
        try:
            resp_js = self.send_request(content)
        except Exception as e:
            print("OCR request failed. Reason : {}".format(e))

        return resp_js

    def post_process(self, resp_js):
        boxObjects = []
        for i in range(1, len(resp_js.text_annotations)):
            # We need to do that because vision sometimes reverse the left and right coords so then we have negative
            # width which causes problems when drawing link buttons
            obj = resp_js
            if obj.text_annotations[i].bounding_poly.vertices[1].x > obj.text_annotations[i].bounding_poly.vertices[3].x:
                leftX = obj.text_annotations[i].bounding_poly.vertices[3].x
            else:
                leftX = obj.text_annotations[i].bounding_poly.vertices[1].x

            if obj.text_annotations[i].bounding_poly.vertices[1].x > obj.text_annotations[i].bounding_poly.vertices[3].x:
                rightX = obj.text_annotations[i].bounding_poly.vertices[1].x
            else:
                rightX = obj.text_annotations[i].bounding_poly.vertices[3].x

            boxObjects.append({
                "id": i-1,
                "text": obj.text_annotations[i].description,
                "left": leftX,
                "width": rightX - leftX,
                "top": obj.text_annotations[i].bounding_poly.vertices[1].y,
                "height": obj.text_annotations[i].bounding_poly.vertices[3].y - obj.text_annotations[i].bounding_poly.vertices[1].y
            })

        return boxObjects

    def convert_to_df(self, boxObjects, image):
        ocr_df = pd.DataFrame(boxObjects)

        # ocr_df = ocr_df.sort_values(by=['top', 'left'], ascending=True).reset_index(drop=True)
        width, height = image.size
        w_scale = 1000/width
        h_scale = 1000/height

        ocr_df = ocr_df.dropna() \
                       .assign(left_scaled = ocr_df.left*w_scale,
                               width_scaled = ocr_df.width*w_scale,
                               top_scaled = ocr_df.top*h_scale,
                               height_scaled = ocr_df.height*h_scale,
                               right_scaled = lambda x: x.left_scaled + x.width_scaled,
                               bottom_scaled = lambda x: x.top_scaled + x.height_scaled)

        float_cols = ocr_df.select_dtypes('float').columns
        ocr_df[float_cols] = ocr_df[float_cols].round(0).astype(int)
        ocr_df = ocr_df.replace(r'^\s*$', np.nan, regex=True)
        ocr_df = ocr_df.dropna().reset_index(drop=True)
        return ocr_df

    def ocr(self, content, image):
        resp_js = self.get_response(content)
        boxObjects = self.post_process(resp_js)
        ocr_df = self.convert_to_df(boxObjects, image)
        return ocr_df
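A hypothetical usage sketch (requires a real Google Cloud Vision service account; both file names below are placeholders). The returned dataframe has one row per detected word with pixel coordinates plus *_scaled columns normalized to a 0-1000 grid.

import json
from PIL import Image
from ocr import VisionClient

with open("service_account.json") as f:      # placeholder credentials file
    client = VisionClient(json.load(f))

with open("sample.png", "rb") as f:          # placeholder document image
    content = f.read()
image = Image.open("sample.png")

ocr_df = client.ocr(content, image)
print(ocr_df[["text", "left", "top", "width", "height"]].head())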
preprocess.py
ADDED
@@ -0,0 +1,111 @@
import torch
from transformers import AutoTokenizer

# class to turn the keys of a dict into attributes (thanks Stackoverflow)
class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self

class Preprocessor():
    def __init__(self, tokenizer):
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        self.argsdict = {'max_seq_length': 512}
        self.args = AttrDict(self.argsdict)

    def get_boxes(self, ocr_df, image):
        words = list(ocr_df.text)
        coordinates = ocr_df[['left', 'top', 'width', 'height']]
        actual_boxes = []
        width, height = image.size
        for idx, row in coordinates.iterrows():
            x, y, w, h = tuple(row)  # the row comes in (left, top, width, height) format
            actual_box = [x, y, x+w, y+h]  # we turn it into (left, top, left+width, top+height) to get the actual box
            actual_boxes.append(actual_box)

        def normalize_box(box, width, height):
            return [
                int(1000 * (box[0] / width)),
                int(1000 * (box[1] / height)),
                int(1000 * (box[2] / width)),
                int(1000 * (box[3] / height)),
            ]

        boxes = []
        for box in actual_boxes:
            boxes.append(normalize_box(box, width, height))

        return words, boxes, actual_boxes

    def convert_example_to_features(self, image, words, boxes, actual_boxes, cls_token_box=[0, 0, 0, 0],
                                    sep_token_box=[1000, 1000, 1000, 1000],
                                    pad_token_box=[0, 0, 0, 0]):
        width, height = image.size

        tokens = []
        token_boxes = []
        actual_bboxes = []  # we use an extra b because actual_boxes is already used
        token_actual_boxes = []
        offset_mapping = []
        for word, box, actual_bbox in zip(words, boxes, actual_boxes):
            word_tokens = self.tokenizer.tokenize(word)
            mapping = self.tokenizer(word, return_offsets_mapping=True).offset_mapping
            offset_mapping.extend(mapping)
            tokens.extend(word_tokens)
            token_boxes.extend([box] * len(word_tokens))
            actual_bboxes.extend([actual_bbox] * len(word_tokens))
            token_actual_boxes.extend([actual_bbox] * len(word_tokens))

        # Truncation: account for [CLS] and [SEP] with "- 2".
        special_tokens_count = 2
        if len(tokens) > self.args.max_seq_length - special_tokens_count:
            tokens = tokens[: (self.args.max_seq_length - special_tokens_count)]
            token_boxes = token_boxes[: (self.args.max_seq_length - special_tokens_count)]
            actual_bboxes = actual_bboxes[: (self.args.max_seq_length - special_tokens_count)]
            token_actual_boxes = token_actual_boxes[: (self.args.max_seq_length - special_tokens_count)]

        # add [SEP] token, with corresponding token boxes and actual boxes
        tokens += [self.tokenizer.sep_token]
        token_boxes += [sep_token_box]
        actual_bboxes += [[0, 0, width, height]]
        token_actual_boxes += [[0, 0, width, height]]

        segment_ids = [0] * len(tokens)

        # next: [CLS] token
        tokens = [self.tokenizer.cls_token] + tokens
        token_boxes = [cls_token_box] + token_boxes
        actual_bboxes = [[0, 0, width, height]] + actual_bboxes
        token_actual_boxes = [[0, 0, width, height]] + token_actual_boxes
        segment_ids = [1] + segment_ids

        input_ids = self.tokenizer.convert_tokens_to_ids(tokens)

        # The mask has 1 for real tokens and 0 for padding tokens. Only real
        # tokens are attended to.
        input_mask = [1] * len(input_ids)

        # Zero-pad up to the sequence length.
        padding_length = self.args.max_seq_length - len(input_ids)
        input_ids += [self.tokenizer.pad_token_id] * padding_length
        input_mask += [0] * padding_length
        segment_ids += [self.tokenizer.pad_token_id] * padding_length
        token_boxes += [pad_token_box] * padding_length
        token_actual_boxes += [pad_token_box] * padding_length

        assert len(input_ids) == self.args.max_seq_length
        assert len(input_mask) == self.args.max_seq_length
        assert len(segment_ids) == self.args.max_seq_length
        assert len(token_boxes) == self.args.max_seq_length
        assert len(token_actual_boxes) == self.args.max_seq_length

        return input_ids, input_mask, segment_ids, token_boxes, token_actual_boxes, offset_mapping

    def process(self, ocr_df, image):
        words, boxes, actual_boxes = self.get_boxes(ocr_df, image)
        input_ids, input_mask, segment_ids, token_boxes, token_actual_boxes, offset_mapping = self.convert_example_to_features(image=image, words=words, boxes=boxes, actual_boxes=actual_boxes)
        input_ids = torch.tensor(input_ids).unsqueeze(0)
        attention_mask = torch.tensor(input_mask).unsqueeze(0)
        token_type_ids = torch.tensor(segment_ids).unsqueeze(0)
        bbox = torch.tensor(token_boxes).unsqueeze(0)
        return input_ids, attention_mask, token_type_ids, bbox, token_actual_boxes, offset_mapping
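A minimal sketch of the expected input schema (fabricated OCR rows and a blank image, purely illustrative; the tokenizer name is the one used in download_model.ipynb):

import pandas as pd
from PIL import Image
from preprocess import Preprocessor

processor = Preprocessor("nielsr/lilt-xlm-roberta-base")
image = Image.new("RGB", (1000, 1000), "white")           # stand-in for a scanned page
ocr_df = pd.DataFrame([                                    # fabricated rows in the ocr.py column layout
    {"text": "Invoice", "left": 100, "top": 50,  "width": 120, "height": 30},
    {"text": "Total:",  "left": 100, "top": 200, "width": 80,  "height": 30},
])
input_ids, attention_mask, token_type_ids, bbox, token_actual_boxes, offset_mapping = \
    processor.process(ocr_df, image=image)
print(input_ids.shape, bbox.shape)                         # torch.Size([1, 512]) torch.Size([1, 512, 4])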
requirements.txt
ADDED
Binary file (3.27 kB)
token_classification.py
ADDED
@@ -0,0 +1,36 @@
import numpy as np

def classifyTokens(model, input_ids, attention_mask, bbox, offset_mapping):
    outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask)
    # take argmax on last dimension to get predicted class ID per token
    predictions = outputs.logits.argmax(-1).squeeze().tolist()
    return predictions

def createEntities(model, predictions, input_ids, offset_mapping):
    # we're only interested in tokens which aren't subwords
    # we'll use the offset mapping for that
    offset_mapping = np.array(offset_mapping)
    is_subword = np.array(offset_mapping.squeeze().tolist())[:, 0] != 0

    # maps entity label names to their integer ids (these ids are what the RE model expects)
    id2label = {"HEADER": 0, "QUESTION": 1, "ANSWER": 2}

    # finally, store recognized "question" and "answer" entities in a list
    entities = []
    current_entity = None
    start = None
    end = None

    for idx, (id, pred) in enumerate(zip(input_ids[0].tolist(), predictions)):
        if not is_subword[idx]:
            predicted_label = model.config.id2label[pred]
            if predicted_label.startswith("B") and current_entity is None:
                # means we're at the start of a new entity
                current_entity = predicted_label.replace("B-", "")
                start = idx
            if current_entity is not None and current_entity not in predicted_label:
                # means we're at the end of the current entity
                end = idx
                entities.append((start, end, current_entity, id2label[current_entity]))
                current_entity = None

    return entities